diff options
35 files changed, 884 insertions, 729 deletions
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index fae3b6783..b2206900b 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -25,7 +25,6 @@ go_library( "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/metric", - "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", "//pkg/sentry/fs", diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index fe11fca9c..dcf898c0a 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -28,9 +28,9 @@ import ( "bytes" "fmt" "io" + "io/ioutil" "math" "reflect" - "sync/atomic" "syscall" "time" @@ -43,7 +43,6 @@ import ( "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/metric" - "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" @@ -308,16 +307,8 @@ type socketOpsCommon struct { skType linux.SockType protocol int - // readViewHasData is 1 iff readView has data to be read, 0 otherwise. - // Must be accessed using atomic operations. It must only be written - // with readMu held but can be read without holding readMu. The latter - // is required to avoid deadlocks in epoll Readiness checks. - readViewHasData uint32 - // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` - // readView contains the remaining payload from the last packet. - readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. readCM socket.IPControlMessages @@ -336,8 +327,8 @@ type socketOpsCommon struct { // valid when timestampValid is true. It is protected by readMu. timestampNS int64 - // sockOptInq corresponds to TCP_INQ. It is implemented at this level - // because it takes into account data from readView. + // TODO(b/153685824): Move this to SocketOptions. + // sockOptInq corresponds to TCP_INQ. sockOptInq bool } @@ -377,41 +368,23 @@ func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } -// fetchReadView updates the readView field of the socket if it's currently -// empty. It assumes that the socket is locked. -// // Precondition: s.readMu must be held. -func (s *socketOpsCommon) fetchReadView() *syserr.Error { - if len(s.readView) > 0 { - return nil - } - s.readView = nil - s.sender = tcpip.FullAddress{} - s.linkPacketInfo = tcpip.LinkPacketInfo{} +func (s *socketOpsCommon) readLocked(dst io.Writer, count int, peek bool) (numRead, numTotal int, serr *syserr.Error) { + res, err := s.Endpoint.Read(dst, count, tcpip.ReadOptions{ + Peek: peek, + NeedRemoteAddr: true, + NeedLinkPacketInfo: true, + }) - var v buffer.View - var cms tcpip.ControlMessages - var err *tcpip.Error + // Assign these anyways. + s.readCM = socket.NewIPControlMessages(s.family, res.ControlMessages) + s.sender = res.RemoteAddr + s.linkPacketInfo = res.LinkPacketInfo - switch e := s.Endpoint.(type) { - // The ordering of these interfaces matters. The most specific - // interfaces must be specified before the more generic Endpoint - // interface. - case tcpip.PacketEndpoint: - v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) - case tcpip.Endpoint: - v, cms, err = e.Read(&s.sender) - } if err != nil { - atomic.StoreUint32(&s.readViewHasData, 0) - return syserr.TranslateNetstackError(err) + return 0, 0, syserr.TranslateNetstackError(err) } - - s.readView = v - s.readCM = socket.NewIPControlMessages(s.family, cms) - atomic.StoreUint32(&s.readViewHasData, 1) - - return nil + return res.Count, res.Total, nil } // Release implements fs.FileOperations.Release. @@ -460,38 +433,14 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // WriteTo implements fs.FileOperations.WriteTo. func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { s.readMu.Lock() + defer s.readMu.Unlock() - // Copy as much data as possible. - done := int64(0) - for count > 0 { - // This may return a blocking error. - if err := s.fetchReadView(); err != nil { - s.readMu.Unlock() - return done, err.ToError() - } - - // Write to the underlying file. - n, err := dst.Write(s.readView) - done += int64(n) - count -= int64(n) - if dup { - // That's all we support for dup. This is generally - // supported by any Linux system calls, but the - // expectation is that now a caller will call read to - // actually remove these bytes from the socket. - break - } - - // Drop that part of the view. - s.readView.TrimFront(n) - if err != nil { - s.readMu.Unlock() - return done, err - } + // This may return a blocking error. + n, _, err := s.readLocked(dst, int(count), dup /* peek */) + if err != nil { + return 0, err.ToError() } - - s.readMu.Unlock() - return done, nil + return int64(n), nil } // ioSequencePayload implements tcpip.Payload. @@ -627,17 +576,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader // Readiness returns a mask of ready events for socket s. func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { - r := s.Endpoint.Readiness(mask) - - // Check our cached value iff the caller asked for readability and the - // endpoint itself is currently not readable. - if (mask & ^r & waiter.EventIn) != 0 { - if atomic.LoadUint32(&s.readViewHasData) == 1 { - r |= waiter.EventIn - } - } - - return r + return s.Endpoint.Readiness(mask) } func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { @@ -2618,66 +2557,20 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return a, l, nil } -// coalescingRead is the fast path for non-blocking, non-peek, stream-based -// case. It coalesces as many packets as possible before returning to the -// caller. +// streamRead is the fast path for non-blocking, non-peek, stream-based socket. // // Precondition: s.readMu must be locked. -func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { - var err *syserr.Error - var copied int - - // Copy as many views as possible into the user-provided buffer. - for { - // Always do at least one fetchReadView, even if the number of bytes to - // read is 0. - err = s.fetchReadView() - if err != nil || len(s.readView) == 0 { - break - } - if dst.NumBytes() == 0 { - break - } - - var n int - var e error - if discard { - n = len(s.readView) - if int64(n) > dst.NumBytes() { - n = int(dst.NumBytes()) - } - } else { - n, e = dst.CopyOut(ctx, s.readView) - // Set the control message, even if 0 bytes were read. - if e == nil { - s.updateTimestamp() - } - } - copied += n - s.readView.TrimFront(n) - - dst = dst.DropFirst(n) - if e != nil { - err = syserr.FromError(e) - break - } - // If we are done reading requested data then stop. - if dst.NumBytes() == 0 { - break - } - } - - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) +func (s *socketOpsCommon) streamRead(ctx context.Context, dst io.Writer, count int) (int, *syserr.Error) { + // Always do at least one read, even if the number of bytes to read is 0. + var n int + n, _, err := s.readLocked(dst, count, false /* peek */) + if err != nil { + return 0, err } - - // If we managed to copy something, we must deliver it. - if copied > 0 { - s.Endpoint.ModerateRecvBuf(copied) - return copied, nil + if n > 0 { + s.Endpoint.ModerateRecvBuf(n) } - - return 0, err + return n, nil } func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { @@ -2689,7 +2582,7 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { return } cmsg.IP.HasInq = true - cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) + cmsg.IP.Inq = int32(rcvBufUsed) } func toLinuxPacketType(pktType tcpip.PacketType) uint8 { @@ -2726,7 +2619,21 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // bytes of data to be discarded, rather than passed back in a // caller-supplied buffer. s.readMu.Lock() - n, err := s.coalescingRead(ctx, dst, trunc) + + var w io.Writer + if trunc { + w = ioutil.Discard + } else { + w = dst.Writer(ctx) + } + + n, err := s.streamRead(ctx, w, int(dst.NumBytes())) + + if err == nil && !trunc { + // Set the control message, even if 0 bytes were read. + s.updateTimestamp() + } + cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) s.readMu.Unlock() @@ -2736,18 +2643,32 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq s.readMu.Lock() defer s.readMu.Unlock() - if err := s.fetchReadView(); err != nil { + // MSG_TRUNC with MSG_PEEK on a TCP socket returns the + // amount that could be read, and does not write to buffer. + isTCPPeekTrunc := !isPacket && peek && trunc + + var w io.Writer + if isTCPPeekTrunc { + w = ioutil.Discard + } else { + w = dst.Writer(ctx) + } + + var numRead, numTotal int + var err *syserr.Error + numRead, numTotal, err = s.readLocked(w, int(dst.NumBytes()), peek) + if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, err } - if !isPacket && peek && trunc { - // MSG_TRUNC with MSG_PEEK on a TCP socket returns the - // amount that could be read. + if isTCPPeekTrunc { + // TCP endpoint does not return the total bytes in buffer as numTotal. + // We need to query it from socket option. rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } - available := len(s.readView) + int(rql) + available := int(rql) bufLen := int(dst.NumBytes()) if available < bufLen { return available, 0, nil, 0, socket.ControlMessages{}, nil @@ -2755,11 +2676,9 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq return bufLen, 0, nil, 0, socket.ControlMessages{}, nil } - n, err := dst.CopyOut(ctx, s.readView) // Set the control message, even if 0 bytes were read. - if err == nil { - s.updateTimestamp() - } + s.updateTimestamp() + var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { @@ -2772,58 +2691,33 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq } if peek { - if l := len(s.readView); trunc && l > n { + if trunc && numTotal > numRead { // isPacket must be true. - return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err) + return numTotal, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), nil } - - if isPacket || err != nil { - return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err) - } - - // We need to peek beyond the first message. - dst = dst.DropFirst(n) - num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) { - n, err := s.Endpoint.Peek(dsts) - // TODO(b/78348848): Handle peek timestamp. - if err != nil { - return int64(n), syserr.TranslateNetstackError(err).ToError() - } - return int64(n), nil - }}) - n += int(num) - if err == syserror.ErrWouldBlock && n > 0 { - // We got some data, so no need to return an error. - err = nil - } - return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err) + return numRead, 0, nil, 0, s.controlMessages(), nil } var msgLen int if isPacket { - msgLen = len(s.readView) - s.readView = nil + msgLen = numTotal } else { - msgLen = int(n) - s.readView.TrimFront(int(n)) - } - - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) + msgLen = numRead } var flags int - if msgLen > int(n) { + if msgLen > numRead { flags |= linux.MSG_TRUNC } + n := numRead if trunc { n = msgLen } cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) - return n, flags, addr, addrLen, cmsg, syserr.FromError(err) + return n, flags, addr, addrLen, cmsg, nil } func (s *socketOpsCommon) controlMessages() socket.ControlMessages { @@ -3090,11 +2984,6 @@ func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.Sy return 0, syserr.TranslateNetstackError(terr).ToError() } - // Add bytes removed from the endpoint but not yet sent to the caller. - s.readMu.Lock() - v += len(s.readView) - s.readMu.Unlock() - if v > math.MaxInt32 { v = math.MaxInt32 } diff --git a/pkg/syserr/netstack.go b/pkg/syserr/netstack.go index 77c3c110c..2756d4471 100644 --- a/pkg/syserr/netstack.go +++ b/pkg/syserr/netstack.go @@ -48,6 +48,7 @@ var ( ErrInvalidOptionValue = New(tcpip.ErrInvalidOptionValue.String(), linux.EINVAL) ErrBroadcastDisabled = New(tcpip.ErrBroadcastDisabled.String(), linux.EACCES) ErrNotPermittedNet = New(tcpip.ErrNotPermitted.String(), linux.EPERM) + ErrBadBuffer = New(tcpip.ErrBadBuffer.String(), linux.EFAULT) ) var netstackErrorTranslations map[string]*Error @@ -100,6 +101,7 @@ func init() { addErrMapping(tcpip.ErrBroadcastDisabled, ErrBroadcastDisabled) addErrMapping(tcpip.ErrNotPermitted, ErrNotPermittedNet) addErrMapping(tcpip.ErrAddressFamilyNotSupported, ErrAddressFamilyNotSupported) + addErrMapping(tcpip.ErrBadBuffer, ErrBadBuffer) } // TranslateNetstackError converts an error from the tcpip package to a sentry diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 4f551cd92..7193f56ad 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -286,45 +286,47 @@ type opErrorer interface { // commonRead implements the common logic between net.Conn.Read and // net.PacketConn.ReadFrom. -func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { +func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) { select { case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) + return 0, errorer.newOpError("read", &timeoutError{}) default: } - read, _, err := ep.Read(addr) + w := tcpip.SliceWriter(b) + opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} + res, err := ep.Read(&w, len(b), opts) if err == tcpip.ErrWouldBlock { - if dontWait { - return nil, errWouldBlock - } // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { - read, _, err = ep.Read(addr) + res, err = ep.Read(&w, len(b), opts) if err != tcpip.ErrWouldBlock { break } select { case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) + return 0, errorer.newOpError("read", &timeoutError{}) case <-notifyCh: } } } if err == tcpip.ErrClosedForReceive { - return nil, io.EOF + return 0, io.EOF } if err != nil { - return nil, errorer.newOpError("read", errors.New(err.String())) + return 0, errorer.newOpError("read", errors.New(err.String())) } - return read, nil + if addr != nil { + *addr = res.RemoteAddr + } + return res.Count, nil } // Read implements net.Conn.Read. @@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) { deadline := c.readCancel() - numRead := 0 - defer func() { - if numRead != 0 { - c.ep.ModerateRecvBuf(numRead) - } - }() - for numRead != len(b) { - if len(c.read) == 0 { - var err error - c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) - if err != nil { - if numRead != 0 { - return numRead, nil - } - return numRead, err - } - } - n := copy(b[numRead:], c.read) - c.read.TrimFront(n) - numRead += n - if len(c.read) == 0 { - c.read = nil - } + n, err := commonRead(b, c.ep, c.wq, deadline, nil, c) + if n != 0 { + c.ep.ModerateRecvBuf(n) } - return numRead, nil + return n, err } // Write implements net.Conn.Write. @@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { deadline := c.readCancel() var addr tcpip.FullAddress - read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) + n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c) if err != nil { return 0, nil, err } - - return copy(b, read), fullToUDPAddr(addr), nil + return n, fullToUDPAddr(addr), nil } func (c *UDPConn) Write(b []byte) (int, error) { diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8db70a700..5dd1b1b6b 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) { } // Read implements io.Reader. -func (vv *VectorisedView) Read(v View) (copied int, err error) { - count := len(v) +func (vv *VectorisedView) Read(b []byte) (copied int, err error) { + count := len(b) for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { vv.size -= count - copy(v[copied:], vv.views[0][:count]) + copy(b[copied:], vv.views[0][:count]) vv.views[0].TrimFront(count) copied += count return copied, nil } count -= len(vv.views[0]) - copy(v[copied:], vv.views[0]) + copy(b[copied:], vv.views[0]) copied += len(vv.views[0]) vv.removeFirst() } @@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int return copied } +// ReadTo reads up to count bytes from vv to dst. It also removes them from vv +// unless peek is true. +func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) { + var err error + done := 0 + for _, v := range vv.Views() { + remaining := count - done + if remaining <= 0 { + break + } + if len(v) > remaining { + v = v[:remaining] + } + + var n int + n, err = dst.Write(v) + if n > 0 { + done += n + } + if err != nil { + break + } + } + if !peek { + vv.TrimFront(done) + } + return done, err +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { diff --git a/pkg/tcpip/buffer/view_test.go b/pkg/tcpip/buffer/view_test.go index 726e54de9..e0ef8a94d 100644 --- a/pkg/tcpip/buffer/view_test.go +++ b/pkg/tcpip/buffer/view_test.go @@ -235,14 +235,16 @@ func TestToClone(t *testing.T) { } } -func TestVVReadToVV(t *testing.T) { - testCases := []struct { - comment string - vv VectorisedView - bytesToRead int - wantBytes string - leftVV VectorisedView - }{ +type readToTestCases struct { + comment string + vv VectorisedView + bytesToRead int + wantBytes string + leftVV VectorisedView +} + +func createReadToTestCases() []readToTestCases { + return []readToTestCases{ { comment: "large VV, short read", vv: vv(30, "012345678901234567890123456789"), @@ -279,8 +281,10 @@ func TestVVReadToVV(t *testing.T) { leftVV: vv(0, ""), }, } +} - for _, tc := range testCases { +func TestVVReadToVV(t *testing.T) { + for _, tc := range createReadToTestCases() { t.Run(tc.comment, func(t *testing.T) { var readTo VectorisedView inSize := tc.vv.Size() @@ -301,6 +305,52 @@ func TestVVReadToVV(t *testing.T) { } } +func TestVVReadTo(t *testing.T) { + for _, tc := range createReadToTestCases() { + t.Run(tc.comment, func(t *testing.T) { + var dst bytes.Buffer + origSize := tc.vv.Size() + copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, false /* peek */) + if got, want := copied, len(tc.wantBytes); err != nil || got != want { + t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want) + } + if got, want := string(dst.Bytes()), tc.wantBytes; got != want { + t.Errorf("got dst = %q, want %q", got, want) + } + if got, want := tc.vv.Size(), origSize-copied; got != want { + t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) + } + if got, want := string(tc.vv.ToView()), string(tc.leftVV.ToView()); got != want { + t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) + } + }) + } +} + +func TestVVReadToPeek(t *testing.T) { + for _, tc := range createReadToTestCases() { + t.Run(tc.comment, func(t *testing.T) { + var dst bytes.Buffer + origSize := tc.vv.Size() + origData := string(tc.vv.ToView()) + copied, err := tc.vv.ReadTo(&dst, tc.bytesToRead, true /* peek */) + if got, want := copied, len(tc.wantBytes); err != nil || got != want { + t.Errorf("got ReadTo(&dst, %d, false) = %d, %v; want %d, nil", tc.bytesToRead, got, err, want) + } + if got, want := string(dst.Bytes()), tc.wantBytes; got != want { + t.Errorf("got dst = %q, want %q", got, want) + } + // Expect tc.vv is unchanged. + if got, want := tc.vv.Size(), origSize; got != want { + t.Errorf("got after-read tc.vv.Size() = %d, want %d", got, want) + } + if got, want := string(tc.vv.ToView()), origData; got != want { + t.Errorf("got after-read data in tc.vv = %q, want %q", got, want) + } + }) + } +} + func TestVVRead(t *testing.T) { testCases := []struct { comment string diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 91971b687..0ac2000ca 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -1603,3 +1603,15 @@ func IPv6RouterAlert(want header.IPv6RouterAlertValue) IPv6ExtHdrOptionChecker { } } } + +// IgnoreCmpPath returns a cmp.Option that ignores listed field paths. +func IgnoreCmpPath(paths ...string) cmp.Option { + ignores := map[string]struct{}{} + for _, path := range paths { + ignores[path] = struct{}{} + } + return cmp.FilterPath(func(path cmp.Path) bool { + _, ok := ignores[path.String()] + return ok + }, cmp.Ignore()) +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index ef62fe6fc..1c4919b1e 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -15,9 +15,11 @@ package ipv4_test import ( + "bytes" "context" "encoding/hex" "fmt" + "io/ioutil" "math" "net" "testing" @@ -2408,18 +2410,26 @@ func TestReceiveFragments(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) } + const rcvSize = 65536 // Account for reassembled packets. for i, expectedPayload := range test.expectedPayloads { - gotPayload, _, err := ep.Read(nil) + var buf bytes.Buffer + result, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("(i=%d) Read(nil): %s", i, err) + t.Fatalf("(i=%d) Read: %s", i, err) } - if diff := cmp.Diff(buffer.View(expectedPayload), gotPayload); diff != "" { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(expectedPayload), + Total: len(expectedPayload), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("(i=%d) ep.Read: unexpected result (-want +got):\n%s", i, diff) + } + if diff := cmp.Diff(expectedPayload, buf.Bytes()); diff != "" { t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) } } - if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) } diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 5f07d3af8..360025b20 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -15,8 +15,10 @@ package ipv6 import ( + "bytes" "encoding/hex" "fmt" + "io/ioutil" "math" "net" "testing" @@ -844,13 +846,14 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { }, } + const mtu = header.IPv6MinimumMTU for _, test := range tests { t.Run(test.name, func(t *testing.T) { s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) + e := channel.New(1, mtu, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -979,17 +982,24 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { if got := stats.Value(); got != 1 { t.Errorf("got UDP Rx Packets = %d, want = 1", got) } - gotPayload, _, err := ep.Read(nil) + var buf bytes.Buffer + result, err := ep.Read(&buf, mtu, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("Read(nil): %s", err) + t.Fatalf("Read: %s", err) + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(udpPayload), + Total: len(udpPayload), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) } - if diff := cmp.Diff(buffer.View(udpPayload), gotPayload); diff != "" { + if diff := cmp.Diff(udpPayload, buf.Bytes()); diff != "" { t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } // Should not have any more UDP packets. - if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if res, err := ep.Read(ioutil.Discard, mtu, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) } @@ -1969,18 +1979,20 @@ func TestReceiveIPv6Fragments(t *testing.T) { t.Errorf("got UDP Rx Packets = %d, want = %d", got, want) } + const rcvSize = 65536 // Account for reassembled packets. for i, p := range test.expectedPayloads { - gotPayload, _, err := ep.Read(nil) + var buf bytes.Buffer + _, err := ep.Read(&buf, rcvSize, tcpip.ReadOptions{}) if err != nil { - t.Fatalf("(i=%d) Read(nil): %s", i, err) + t.Fatalf("(i=%d) Read: %s", i, err) } - if diff := cmp.Diff(buffer.View(p), gotPayload); diff != "" { + if diff := cmp.Diff(p, buf.Bytes()); diff != "" { t.Errorf("(i=%d) got UDP payload mismatch (-want +got):\n%s", i, diff) } } - if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("(last) got Read(nil) = (%x, _, %v), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if res, err := ep.Read(ioutil.Discard, rcvSize, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("(last) got Read = (%v, %v), want = (_, %s)", res, err, tcpip.ErrWouldBlock) } }) } diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 51d428049..4777163cd 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -44,6 +44,7 @@ import ( "bufio" "fmt" "log" + "math" "math/rand" "net" "os" @@ -200,7 +201,7 @@ func main() { // connection from its side. wq.EventRegister(&waitEntry, waiter.EventIn) for { - v, _, err := ep.Read(nil) + _, err := ep.Read(os.Stdout, math.MaxUint16, tcpip.ReadOptions{}) if err != nil { if err == tcpip.ErrClosedForReceive { break @@ -213,8 +214,6 @@ func main() { log.Fatal("Read() failed:", err) } - - os.Stdout.Write(v) } wq.EventUnregister(&waitEntry) diff --git a/pkg/tcpip/sample/tun_tcp_echo/main.go b/pkg/tcpip/sample/tun_tcp_echo/main.go index 1c2afd554..a80fa0474 100644 --- a/pkg/tcpip/sample/tun_tcp_echo/main.go +++ b/pkg/tcpip/sample/tun_tcp_echo/main.go @@ -20,8 +20,10 @@ package main import ( + "bytes" "flag" "log" + "math" "math/rand" "net" "os" @@ -54,7 +56,8 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { defer wq.EventUnregister(&waitEntry) for { - v, _, err := ep.Read(nil) + var buf bytes.Buffer + _, err := ep.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}) if err != nil { if err == tcpip.ErrWouldBlock { <-notifyCh @@ -64,7 +67,7 @@ func echo(wq *waiter.Queue, ep tcpip.Endpoint) { return } - ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) + ep.Write(tcpip.SlicePayload(buf.Bytes()), tcpip.WriteOptions{}) } } diff --git a/pkg/tcpip/stack/transport_demuxer_test.go b/pkg/tcpip/stack/transport_demuxer_test.go index 737d8d912..859278f0b 100644 --- a/pkg/tcpip/stack/transport_demuxer_test.go +++ b/pkg/tcpip/stack/transport_demuxer_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "io/ioutil" "math" "math/rand" "testing" @@ -351,7 +352,7 @@ func TestBindToDeviceDistribution(t *testing.T) { } ep := <-pollChannel - if _, _, err := ep.Read(nil); err != nil { + if _, err := ep.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != nil { t.Fatalf("Read on endpoint %d failed: %s", eps[ep], err) } stats[ep]++ diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index dd552b8b9..a5facf578 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -15,6 +15,7 @@ package stack_test import ( + "io" "testing" "gvisor.dev/gvisor/pkg/tcpip" @@ -85,8 +86,8 @@ func (*fakeTransportEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return mask } -func (*fakeTransportEndpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return buffer.View{}, tcpip.ControlMessages{}, nil +func (*fakeTransportEndpoint) Read(io.Writer, int, tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { + return tcpip.ReadResult{}, nil } func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { @@ -110,10 +111,6 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (*fakeTransportEndpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // SetSockOpt sets a socket option. Currently not supported. func (*fakeTransportEndpoint) SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error { return tcpip.ErrInvalidEndpointState diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ef0f51f1a..f798056c0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -31,6 +31,7 @@ package tcpip import ( "errors" "fmt" + "io" "math/bits" "reflect" "strconv" @@ -39,7 +40,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/waiter" ) @@ -113,6 +113,7 @@ var ( ErrNotPermitted = &Error{msg: "operation not permitted"} ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ErrMalformedHeader = &Error{msg: "header is malformed"} + ErrBadBuffer = &Error{msg: "bad buffer"} ) var messageToError map[string]*Error @@ -162,6 +163,7 @@ func StringToError(s string) *Error { ErrNotPermitted, ErrAddressFamilyNotSupported, ErrMalformedHeader, + ErrBadBuffer, } messageToError = make(map[string]*Error) @@ -496,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) { return s[:size], nil } +var _ io.Writer = (*SliceWriter)(nil) + +// SliceWriter implements io.Writer for slices. +type SliceWriter []byte + +// Write implements io.Writer.Write. +func (s *SliceWriter) Write(b []byte) (int, error) { + n := copy(*s, b) + *s = (*s)[n:] + if n < len(b) { + return n, io.ErrShortWrite + } + return n, nil +} + // A ControlMessages contains socket control messages for IP sockets. // // +stateify savable @@ -552,6 +569,40 @@ type PacketOwner interface { GID() uint32 } +// ReadOptions contains options for Endpoint.Read. +type ReadOptions struct { + // Peek indicates whether this read is a peek. + Peek bool + + // NeedRemoteAddr indicates whether to return the remote address, if + // supported. + NeedRemoteAddr bool + + // NeedLinkPacketInfo indicates whether to return the link-layer information, + // if supported. + NeedLinkPacketInfo bool +} + +// ReadResult represents result for a successful Endpoint.Read. +type ReadResult struct { + // Count is the number of bytes received and written to the buffer. + Count int + + // Total is the number of bytes of the received packet. This can be used to + // determine whether the read is truncated. + Total int + + // ControlMessages is the control messages received. + ControlMessages ControlMessages + + // RemoteAddr is the remote address if ReadOptions.NeedAddr is true. + RemoteAddr FullAddress + + // LinkPacketInfo is the link-layer information of the received packet if + // ReadOptions.NeedLinkPacketInfo is true. + LinkPacketInfo LinkPacketInfo +} + // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) // that exposes functionality like read, write, connect, etc. to users of the // networking stack. @@ -566,11 +617,15 @@ type Endpoint interface { // Abort is best effort; implementing Abort with Close is acceptable. Abort() - // Read reads data from the endpoint and optionally returns the sender. + // Read reads data from the endpoint and optionally writes to dst. + // + // This method does not block if there is no data pending; in this case, + // ErrWouldBlock is returned. // - // This method does not block if there is no data pending. It will also - // either return an error or data, never both. - Read(*FullAddress) (buffer.View, ControlMessages, *Error) + // If non-zero number of bytes are successfully read and written to dst, err + // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer + // should be returned. + Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. @@ -592,11 +647,6 @@ type Endpoint interface { // not). The channel is only non-nil in this case. Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) - // Peek reads data without consuming it from the endpoint. - // - // This method does not block if there is no data pending. - Peek([][]byte) (int64, *Error) - // Connect connects the endpoint to its peer. Specifying a NIC is // optional. // @@ -703,17 +753,6 @@ type LinkPacketInfo struct { PktType PacketType } -// PacketEndpoint are additional methods that are only implemented by Packet -// endpoints. -type PacketEndpoint interface { - // ReadPacket reads a datagram/packet from the endpoint and optionally - // returns the sender and additional LinkPacketInfo. - // - // This method does not block if there is no data pending. It will also - // either return an error or data, never both. - ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) -} - // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index bb3b2ed0d..ca1e88e99 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -15,6 +15,7 @@ go_test( deps = [ "//pkg/tcpip", "//pkg/tcpip/buffer", + "//pkg/tcpip/checker", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/ethernet", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 907565ac4..60054d6ef 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -15,12 +15,13 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/ethernet" "gvisor.dev/gvisor/pkg/tcpip/link/nested" @@ -382,24 +383,33 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. <-ch - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, len(data), opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) } - if diff := cmp.Diff(v, buffer.View(data)); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: len(data), + Total: len(data), + RemoteAddr: tcpip.FullAddress{Addr: expectedFrom}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != expectedFrom { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, expectedFrom) + if diff := cmp.Diff(buf.Bytes(), data); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() } - return addr + return res.RemoteAddr } addr := read(epsAndAddrs.serverReadableCH, epsAndAddrs.serverEP, data, epsAndAddrs.clientAddr) diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index b41b72381..209da3903 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -15,12 +15,13 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/pipe" "gvisor.dev/gvisor/pkg/tcpip/network/arp" @@ -86,21 +87,21 @@ func TestPing(t *testing.T) { transProto tcpip.TransportProtocolNumber netProto tcpip.NetworkProtocolNumber remoteAddr tcpip.Address - icmpBuf func(*testing.T) buffer.View + icmpBuf func(*testing.T) []byte }{ { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, remoteAddr: ipv4Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) buffer.View { + icmpBuf: func(t *testing.T) []byte { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) hdr.SetType(header.ICMPv4Echo) if n := copy(hdr.Payload(), data[:]); n != len(data) { t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) } - return buffer.View(hdr) + return hdr }, }, { @@ -108,14 +109,14 @@ func TestPing(t *testing.T) { transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, remoteAddr: ipv6Addr2.AddressWithPrefix.Address, - icmpBuf: func(t *testing.T) buffer.View { + icmpBuf: func(t *testing.T) []byte { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) hdr.SetType(header.ICMPv6EchoRequest) if n := copy(hdr.Payload(), data[:]); n != len(data) { t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) } - return buffer.View(hdr) + return hdr }, }, } @@ -200,16 +201,25 @@ func TestPing(t *testing.T) { // Wait for the endpoint to be readable. <-waiterCH - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, len(icmpBuf), opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(icmpBuf), opts, err) } - if diff := cmp.Diff(v[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: test.remoteAddr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != test.remoteAddr { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.remoteAddr) + if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], icmpBuf[icmpDataOffset:]); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } }) } diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index baaa741cd..cf9e86c3c 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -15,12 +15,14 @@ package integration_test import ( + "bytes" "testing" "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" @@ -238,21 +240,28 @@ func TestLoopbackAcceptAllInSubnetUDP(t *testing.T) { t.Fatalf("got sep.Write(_, _) = (%d, _, nil), want = (%d, _, nil)", n, want) } - var addr tcpip.FullAddress - if gotPayload, _, err := rep.Read(&addr); test.expectRx { + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + if res, err := rep.Read(&buf, len(data), opts); test.expectRx { if err != nil { - t.Fatalf("reep.Read(_): %s", err) - } - if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + t.Fatalf("rep.Read(_, %d, %#v): %s", len(data), opts, err) } - if addr.Addr != test.addAddress.AddressWithPrefix.Address { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.addAddress.AddressWithPrefix.Address) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{ + Addr: test.addAddress.AddressWithPrefix.Address, + }, + }, res, + checker.IgnoreCmpPath("ControlMessages", "RemoteAddr.NIC", "RemoteAddr.Port"), + ); diff != "" { + t.Errorf("rep.Read: unexpected result (-want +got):\n%s", diff) } - } else { - if err != tcpip.ErrWouldBlock { - t.Fatalf("got rep.Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + } else if err != tcpip.ErrWouldBlock { + t.Fatalf("got rep.Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) } }) } diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index 20f8a7e6c..fae6c256a 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -15,12 +15,14 @@ package integration_test import ( + "bytes" "net" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -462,17 +464,23 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { } test.rxUDP(e, test.remoteAddr, test.dstAddr, data) - if gotPayload, _, err := ep.Read(nil); test.expectRx { + var buf bytes.Buffer + var opts tcpip.ReadOptions + if res, err := ep.Read(&buf, len(data), opts); test.expectRx { if err != nil { - t.Fatalf("Read(nil): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", len(data), opts, err) } - if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - } else { - if err != tcpip.ErrWouldBlock { - t.Fatalf("got Read(nil) = (%x, _, %s), want = (_, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) } + } else if err != tcpip.ErrWouldBlock { + t.Fatalf("got Read = (%v, %s) [with data %x], want = (_, %s)", res, err, buf.Bytes(), tcpip.ErrWouldBlock) } }) } @@ -589,9 +597,19 @@ func TestReuseAddrAndBroadcast(t *testing.T) { // Wait for the endpoint to become readable. <-rep.ch - if gotPayload, _, err := rep.ep.Read(nil); err != nil { - t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) - } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { + var buf bytes.Buffer + result, err := rep.ep.Read(&buf, len(data), tcpip.ReadOptions{}) + if err != nil { + t.Errorf("(eps[%d] write) eps[%d].Read: %s", i, j, err) + continue + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("(eps[%d] write) eps[%d].Read: unexpected result (-want +got):\n%s", i, j, diff) + } + if diff := cmp.Diff([]byte(data), buf.Bytes()); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) } } @@ -719,10 +737,20 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { t.Fatalf("ep.SetSockOpt(&%#v): %s", addOpt, err) } test.rxUDP(e, test.remoteAddr, test.multicastAddr, data) - if gotPayload, _, err := ep.Read(nil); err != nil { - t.Fatalf("ep.Read(nil): %s", err) - } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { - t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + var buf bytes.Buffer + result, err := ep.Read(&buf, len(data), tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("ep.Read: %s", err) + } else { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + if diff := cmp.Diff(data, buf.Bytes()); diff != "" { + t.Errorf("got UDP payload mismatch (-want +got):\n%s", diff) + } } // We should not receive UDP packets to the group once we leave @@ -731,8 +759,8 @@ func TestUDPAddRemoveMembershipSocketOption(t *testing.T) { if err := ep.SetSockOpt(&removeOpt); err != nil { t.Fatalf("ep.SetSockOpt(&%#v): %s", removeOpt, err) } - if gotPayload, _, err := ep.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got ep.Read(nil) = (%x, _, %s), want = (nil, _, %s)", gotPayload, err, tcpip.ErrWouldBlock) + if _, err := ep.Read(&buf, 1, tcpip.ReadOptions{}); err != tcpip.ErrWouldBlock { + t.Fatalf("got ep.Read = (_, %s), want = (_, %s)", err, tcpip.ErrWouldBlock) } }) } diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go index 02fc47015..52cf89b54 100644 --- a/pkg/tcpip/tests/integration/route_test.go +++ b/pkg/tcpip/tests/integration/route_test.go @@ -15,11 +15,14 @@ package integration_test import ( + "bytes" + "math" "testing" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/checker" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/loopback" @@ -203,16 +206,25 @@ func TestLocalPing(t *testing.T) { // Wait for the endpoint to become readable. <-ch - var addr tcpip.FullAddress - v, _, err := ep.Read(&addr) + var buf bytes.Buffer + opts := tcpip.ReadOptions{NeedRemoteAddr: true} + res, err := ep.Read(&buf, math.MaxUint16, opts) if err != nil { - t.Fatalf("ep.Read(_): %s", err) + t.Fatalf("ep.Read(_, %d, %#v): %s", math.MaxUint16, opts, err) } - if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { - t.Errorf("received data mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: test.localAddr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) } - if addr.Addr != test.localAddr { - t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + if diff := cmp.Diff(buf.Bytes()[icmpDataOffset:], []byte(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) } test.checkLinkEndpoint(t, e) @@ -338,14 +350,27 @@ func TestLocalUDP(t *testing.T) { <-serverCH var clientAddr tcpip.FullAddress - if v, _, err := server.Read(&clientAddr); err != nil { + var readBuf bytes.Buffer + if read, err := server.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("server.Read(_): %s", err) } else { - if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { - t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + clientAddr = read.RemoteAddr + + if diff := cmp.Diff(tcpip.ReadResult{ + Count: readBuf.Len(), + Total: readBuf.Len(), + RemoteAddr: tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + }, + }, read, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("server.Read: unexpected result (-want +got):\n%s", diff) } - if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { - t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + if diff := cmp.Diff(buffer.View(clientPayload), buffer.View(readBuf.Bytes())); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() @@ -367,15 +392,23 @@ func TestLocalUDP(t *testing.T) { // Wait for the client endpoint to become readable. <-clientCH - var gotServerAddr tcpip.FullAddress - if v, _, err := client.Read(&gotServerAddr); err != nil { + readBuf.Reset() + if read, err := client.Read(&readBuf, math.MaxUint16, tcpip.ReadOptions{NeedRemoteAddr: true}); err != nil { t.Fatalf("client.Read(_): %s", err) } else { - if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { - t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: readBuf.Len(), + Total: readBuf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: serverAddr.Addr}, + }, read, checker.IgnoreCmpPath( + "ControlMessages", + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + t.Errorf("client.Read: unexpected result (-want +got):\n%s", diff) } - if gotServerAddr.Addr != serverAddr.Addr { - t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + if diff := cmp.Diff(buffer.View(serverPayload), buffer.View(readBuf.Bytes())); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) } if t.Failed() { t.FailNow() diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d1e4a7cb7..2eb4457df 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,8 @@ package icmp import ( + "io" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -151,9 +153,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// Read reads data from the endpoint. This method does not block if -// there is no data pending. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { @@ -163,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } p := e.rcvList.Front() - e.rcvList.Remove(p) - e.rcvBufSize -= p.data.Size() + if !opts.Peek { + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + } e.rcvMu.Unlock() - if addr != nil { - *addr = p.senderAddress + res := tcpip.ReadResult{ + Total: p.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + }, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + n, err := p.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -329,11 +344,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return int64(len(v)), nil, nil } -// Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { return nil diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index e5e247342..3ab060751 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -26,6 +26,7 @@ package packet import ( "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -160,8 +161,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.PacketEndpoint.ReadPacket. -func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -173,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn err = tcpip.ErrClosedForReceive } ep.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } packet := ep.rcvList.Front() - ep.rcvList.Remove(packet) - ep.rcvBufSize -= packet.data.Size() + if !opts.Peek { + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + } ep.rcvMu.Unlock() - if addr != nil { - *addr = packet.senderAddr + res := tcpip.ReadResult{ + Total: packet.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: packet.timestampNS, + }, } - - if info != nil { - *info = packet.packetInfo + if opts.NeedRemoteAddr { + res.RemoteAddr = packet.senderAddr + } + if opts.NeedLinkPacketInfo { + res.LinkPacketInfo = packet.packetInfo } - return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil -} - -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return ep.ReadPacket(addr, nil) + n, err := packet.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { @@ -203,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, tcpip.ErrInvalidOptionValue } -// Peek implements tcpip.Endpoint.Peek. -func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be // disconnected, and this function always returns tpcip.ErrNotSupported. func (*endpoint) Disconnect() *tcpip.Error { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 7befcfc9b..dd260535f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -27,6 +27,7 @@ package raw import ( "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -190,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -202,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } pkt := e.rcvList.Front() - e.rcvList.Remove(pkt) - e.rcvBufSize -= pkt.data.Size() + if !opts.Peek { + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() + } e.rcvMu.Unlock() - if addr != nil { - *addr = pkt.senderAddr + res := tcpip.ReadResult{ + Total: pkt.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: pkt.timestampNS, + }, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = pkt.senderAddr } - return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil + n, err := pkt.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // Write implements tcpip.Endpoint.Write. @@ -363,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, return int64(len(payloadBytes)), nil, nil } -// Peek implements tcpip.Endpoint.Peek. -func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() *tcpip.Error { return tcpip.ErrNotSupported diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index cf232b508..7e81203ba 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -112,6 +112,7 @@ go_test( "//pkg/tcpip/transport/tcp/testing/context", "//pkg/test/testutil", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6e3c8860e..8f3981075 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -17,6 +17,7 @@ package tcp import ( "encoding/binary" "fmt" + "io" "math" "runtime" "strings" @@ -27,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" @@ -393,15 +393,28 @@ type endpoint struct { lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` - // The following fields are used to manage the receive queue. The - // protocol goroutine adds ready-for-delivery segments to rcvList, - // which are returned by Read() calls to users. + // rcvReadMu synchronizes calls to Read. // - // Once the peer has closed its send side, rcvClosed is set to true - // to indicate to users that no more data is coming. + // mu and rcvListMu are temporarily released during data copying. rcvReadMu + // must be held during each read to ensure atomicity, so that multiple reads + // do not interleave. + // + // rcvReadMu should be held before holding mu. + rcvReadMu sync.Mutex `state:"nosave"` + + // rcvListMu synchronizes access to rcvList. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` + rcvListMu sync.Mutex `state:"nosave"` + + // rcvList is the queue for ready-for-delivery segments. + // + // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data + // and removing segments from list. A range of segment can be determined, then + // temporarily release mu and rcvListMu while processing the segment range. + // This allows new segments to be appended to the list while processing. + // + // rcvListMu must be held to append segments to list. rcvList segmentList `state:"wait"` rcvClosed bool // rcvBufSize is the total size of the receive buffer. @@ -1309,8 +1322,69 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { e.UnlockUser() } -// Read reads data from the endpoint. -func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { + e.rcvReadMu.Lock() + defer e.rcvReadMu.Unlock() + + // N.B. Here we get a range of segments to be processed. It is safe to not + // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we + // can remove segments from the list through commitRead(). + first, last, serr := e.startRead() + if serr != nil { + if serr == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } + return tcpip.ReadResult{}, serr + } + + var err error + done := 0 + s := first + for s != nil && done < count { + var n int + n, err = s.data.ReadTo(dst, count-done, opts.Peek) + // Book keeping first then error handling. + + done += n + + if opts.Peek { + // For peek, we use the (first, last) range of segment returned from + // startRead. We don't consume the receive buffer, so commitRead should + // not be called. + // + // N.B. It is important to use `last` to determine the last segment, since + // appending can happen while we process, and will lead to data race. + if s == last { + break + } + s = s.Next() + } else { + // N.B. commitRead() conveniently returns the next segment to read, after + // removing the data/segment that is read. + s = e.commitRead(n) + } + + if err != nil { + break + } + } + + // If something is read, we must report it. Report error when nothing is read. + if done == 0 && err != nil { + return tcpip.ReadResult{}, tcpip.ErrBadBuffer + } + return tcpip.ReadResult{ + Count: done, + Total: done, + }, nil +} + +// startRead checks that endpoint is in a readable state, and return the +// inclusive range of segments that can be read. +// +// Precondition: e.rcvReadMu must be held. +func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1319,7 +1393,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // on a receive. It can expect to read any data after the handshake // is complete. RFC793, section 3.9, p58. if e.EndpointState() == StateSynSent { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + return nil, nil, tcpip.ErrWouldBlock } // The endpoint can be read if it's connected, or if it's already closed @@ -1327,61 +1401,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + bufUsed := e.rcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { - e.rcvListMu.Unlock() if s == StateError { if err := e.hardErrorLocked(); err != nil { - return buffer.View{}, tcpip.ControlMessages{}, err + return nil, nil, err } - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + return nil, nil, tcpip.ErrClosedForReceive } e.stats.ReadErrors.NotConnected.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected + return nil, nil, tcpip.ErrNotConnected } - v, err := e.readLocked() - e.rcvListMu.Unlock() - - if err == tcpip.ErrClosedForReceive { - e.stats.ReadErrors.ReadClosed.Increment() - } - return v, tcpip.ControlMessages{}, err -} - -func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { - return buffer.View{}, tcpip.ErrClosedForReceive + return nil, nil, tcpip.ErrClosedForReceive } - return buffer.View{}, tcpip.ErrWouldBlock + return nil, nil, tcpip.ErrWouldBlock } - s := e.rcvList.Front() - views := s.data.Views() - v := views[s.viewToDeliver] - s.viewToDeliver++ + return e.rcvList.Front(), e.rcvList.Back(), nil +} + +// commitRead commits a read of done bytes and returns the next non-empty +// segment to read. Data read from the segment must have also been removed from +// the segment in order for this method to work correctly. +// +// It is performance critical to call commitRead frequently when servicing a big +// Read request, so TCP can make progress timely. Right now, it is designed to +// do this per segment read, hence this method conveniently returns the next +// segment to read while holding the lock. +// +// Precondition: e.rcvReadMu must be held. +func (e *endpoint) commitRead(done int) *segment { + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() - var delta int - if s.viewToDeliver >= len(views) { + memDelta := 0 + s := e.rcvList.Front() + for s != nil && s.data.Size() == 0 { e.rcvList.Remove(s) - // We only free up receive buffer space when the segment is released as the - // segment is still holding on to the views even though some views have been - // read out to the user. - delta = s.segMemSize() + // Memory is only considered released when the whole segment has been + // read. + memDelta += s.segMemSize() s.decRef() + s = e.rcvList.Front() } + e.rcvBufUsed -= done - e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up - // enough buffer space, to either fit an aMSS or half a receive buffer - // (whichever smaller), then notify the protocol goroutine to send a - // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + if memDelta > 0 { + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } } - return v, nil + return e.rcvList.Front() } // isEndpointWritableLocked checks if a given endpoint is writable @@ -1499,64 +1581,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return queueAndSend() } -// Peek reads data without consuming it from the endpoint. -// -// This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) { - e.LockUser() - defer e.UnlockUser() - - // The endpoint can be read if it's connected, or if it's already closed - // but has some pending unread data. - if s := e.EndpointState(); !s.connected() && s != StateClose { - if s == StateError { - return 0, e.hardErrorLocked() - } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return 0, tcpip.ErrInvalidEndpointState - } - - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() - - if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.EndpointState().connected() { - e.stats.ReadErrors.ReadClosed.Increment() - return 0, tcpip.ErrClosedForReceive - } - return 0, tcpip.ErrWouldBlock - } - - // Make a copy of vec so we can modify the slide headers. - vec = append([][]byte(nil), vec...) - - var num int64 - for s := e.rcvList.Front(); s != nil; s = s.Next() { - views := s.data.Views() - - for i := s.viewToDeliver; i < len(views); i++ { - v := views[i] - - for len(v) > 0 { - if len(vec) == 0 { - return num, nil - } - if len(vec[0]) == 0 { - vec = vec[1:] - continue - } - - n := copy(vec[0], v) - v = v[n:] - vec[0] = vec[0][n:] - num += int64(n) - } - } - } - - return num, nil -} - // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. // Precondition: e.mu and e.rcvListMu must be held. diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 5ef73ec74..c5a6d2fba 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -37,7 +37,7 @@ const ( // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. -// segment is mostly immutable, the only field allowed to change is viewToDeliver. +// segment is mostly immutable, the only field allowed to change is data. // // +stateify savable type segment struct { @@ -60,10 +60,7 @@ type segment struct { hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` - // viewToDeliver keeps track of the next View that should be - // delivered by the Read endpoint. - viewToDeliver int + views [8]buffer.View `state:"nosave"` sequenceNumber seqnum.Value ackNumber seqnum.Value flags uint8 @@ -84,6 +81,9 @@ type segment struct { // acked indicates if the segment has already been SACKed. acked bool + + // dataMemSize is the memory used by data initially. + dataMemSize int } func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { @@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() + s.dataMemSize = s.data.Size() return s } @@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s.views[0] = v s.data = buffer.NewVectorisedView(len(v), s.views[:1]) } + s.dataMemSize = s.data.Size() return s } @@ -127,12 +129,12 @@ func (s *segment) clone() *segment { netProto: s.netProto, nicID: s.nicID, remoteLinkAddr: s.remoteLinkAddr, - viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, ep: s.ep, qFlags: s.qFlags, + dataMemSize: s.dataMemSize, } t.data = s.data.Clone(t.views[:]) return t @@ -204,7 +206,7 @@ func (s *segment) payloadSize() int { // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { - return SegSize + s.data.Size() + return SegSize + s.dataMemSize } // parse populates the sequence & ack numbers, flags, and window fields of the diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go index 7dc2741a6..7422d8c02 100644 --- a/pkg/tcpip/transport/tcp/segment_state.go +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -24,16 +24,11 @@ import ( func (s *segment) saveData() buffer.VectorisedView { // We cannot save s.data directly as s.data.views may alias to s.views, // which is not allowed by state framework (in-struct pointer). - v := make([]buffer.View, len(s.data.Views())) - // For views already delivered, we cannot save them directly as they may - // have already been sliced and saved elsewhere (e.g., readViews). - for i := 0; i < s.viewToDeliver; i++ { - v[i] = append([]byte(nil), s.data.Views()[i]...) + vs := make([]buffer.View, len(s.data.Views())) + for i, v := range s.data.Views() { + vs[i] = v } - for i := s.viewToDeliver; i < len(v); i++ { - v[i] = s.data.Views()[i] - } - return buffer.NewVectorisedView(s.data.Size(), v) + return buffer.NewVectorisedView(s.data.Size(), vs) } // loadData is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index cf60d5b53..9fa4672d7 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -17,10 +17,12 @@ package tcp_test import ( "bytes" "fmt" + "io/ioutil" "math" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/rand" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -40,6 +42,64 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// endpointTester provides helper functions to test a tcpip.Endpoint. +type endpointTester struct { + ep tcpip.Endpoint +} + +// CheckReadError issues a read to the endpoint and checking for an error. +func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) { + t.Helper() + res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{}) + if got != want { + t.Fatalf("ep.Read = %s, want %s", got, want) + } + if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" { + t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff) + } +} + +// CheckRead issues a read to the endpoint and checking for a success, returning +// the data read. +func (e *endpointTester) CheckRead(t *testing.T, count int) []byte { + t.Helper() + var buf bytes.Buffer + res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{}) + if err != nil { + t.Fatalf("ep.Read = _, %s; want _, nil", err) + } + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff) + } + return buf.Bytes() +} + +// CheckReadFull reads from the endpoint for exactly count bytes. +func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte { + t.Helper() + var buf bytes.Buffer + var done int + for done < count { + res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{}) + if err == tcpip.ErrWouldBlock { + // Wait for receive to be notified. + select { + case <-notifyRead: + case <-time.After(timeout): + t.Fatalf("Timed out waiting for data to arrive") + } + continue + } else if err != nil { + t.Fatalf("ep.Read = _, %s; want _, nil", err) + } + done += res.Count + } + return buf.Bytes() +} + const ( // defaultMTU is the MTU, in bytes, used throughout the tests, except // where another value is explicitly used. It is chosen to match the MTU @@ -740,9 +800,7 @@ func TestSimpleReceive(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -762,11 +820,7 @@ func TestSimpleReceive(t *testing.T) { } // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - + v := ept.CheckRead(t, defaultMTU) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -1492,14 +1546,11 @@ func TestSynSent(t *testing.T) { t.Fatal("timed out waiting for packet to arrive") } + ept := endpointTester{c.EP} if test.reset { - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused) - } + ept.CheckReadError(t, tcpip.ErrConnectionRefused) } else { - if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted) - } + ept.CheckReadError(t, tcpip.ErrAborted) } if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 { @@ -1524,9 +1575,8 @@ func TestOutOfOrderReceive(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send second half of data first, with seqnum 3 ahead of expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1551,9 +1601,7 @@ func TestOutOfOrderReceive(t *testing.T) { // Wait 200ms and check that no data has been received. time.Sleep(200 * time.Millisecond) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send the first 3 bytes now. c.SendPacket(data[:3], &context.Headers{ @@ -1566,24 +1614,7 @@ func TestOutOfOrderReceive(t *testing.T) { }) // Receive data. - read := make([]byte, 0, 6) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - if err == tcpip.ErrWouldBlock { - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - continue - } - t.Fatalf("Read failed: %s", err) - } - - read = append(read, v...) - } + read := ept.CheckReadFull(t, 6, ch, 5*time.Second) // Check that we received the data in proper order. if !bytes.Equal(data, read) { @@ -1608,9 +1639,8 @@ func TestOutOfOrderFlood(t *testing.T) { rcvBufSz := math.MaxUint16 c.CreateConnected(789, 30000, rcvBufSz) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send 100 packets before the actual one that is expected. data := []byte{1, 2, 3, 4, 5, 6} @@ -1685,9 +1715,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1754,9 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -1837,17 +1865,14 @@ func TestShutdownRead(t *testing.T) { c.CreateConnected(789, 30000, -1 /* epRcvBuf */) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil { t.Fatalf("Shutdown failed: %s", err) } - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) - } + ept.CheckReadError(t, tcpip.ErrClosedForReceive) var want uint64 = 1 if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want { t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want) @@ -1865,10 +1890,8 @@ func TestFullWindowReceive(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - _, _, err := c.EP.Read(nil) - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %s", err) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies // the provided buffer value by tcp.SegOverheadFactor to calculate the actual @@ -1905,11 +1928,7 @@ func TestFullWindowReceive(t *testing.T) { ) // Receive data and check it. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - + v := ept.CheckRead(t, defaultMTU) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } @@ -1991,8 +2010,9 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) { // Read the data so that the subsequent ACK from the endpoint // grows the right edge of the window. - if _, _, err := c.EP.Read(nil); err != nil { - t.Fatalf("got Read(nil) = %s", err) + var buf bytes.Buffer + if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil { + t.Fatalf("c.EP.Read: %s", err) } // Check if we have received max uint16 as our advertised @@ -2027,9 +2047,9 @@ func TestNoWindowShrinking(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) + // Send a 1 byte payload so that we can record the current receive window. // Send a payload of half the size of rcvBufSize. seqNum := iss.Add(1) @@ -2051,11 +2071,7 @@ func TestNoWindowShrinking(t *testing.T) { } // Read the 1 byte payload we just sent. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - if got, want := payload, v; !bytes.Equal(got, want) { + if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) { t.Fatalf("got data: %v, want: %v", got, want) } @@ -2128,24 +2144,8 @@ func TestNoWindowShrinking(t *testing.T) { ), ) - // Wait for receive to be notified. - select { - case <-ch: - case <-time.After(5 * time.Second): - t.Fatalf("Timed out waiting for data to arrive") - } - // Receive data and check it. - read := make([]byte, 0, rcvBufSize) - for len(read) < len(data) { - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - - read = append(read, v...) - } - + read := ept.CheckReadFull(t, len(data), ch, 5*time.Second) if !bytes.Equal(data, read) { t.Fatalf("got data = %v, want = %v", read, data) } @@ -2569,11 +2569,11 @@ func TestZeroScaledWindowReceive(t *testing.T) { // we need to read at 3 packets. sz := 0 for sz < defaultMTU*2 { - v, _, err := c.EP.Read(nil) + res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Read failed: %s", err) } - sz += len(v) + sz += res.Count } checker.IPv4(t, c.GetPacket(), @@ -3268,13 +3268,13 @@ func TestReceiveOnResetConnection(t *testing.T) { loop: for { - switch _, _, err := c.EP.Read(nil); err { + switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err { case tcpip.ErrWouldBlock: select { case <-ch: // Expect the state to be StateError and subsequent Reads to fail with HardError. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset) + if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset { + t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset) } break loop case <-time.After(1 * time.Second): @@ -4164,9 +4164,8 @@ func TestReadAfterClosedState(t *testing.T) { c.WQ.EventRegister(&we, waiter.EventIn) defer c.WQ.EventUnregister(&we) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Shutdown immediately for write, check that we get a FIN. if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil { @@ -4224,35 +4223,31 @@ func TestReadAfterClosedState(t *testing.T) { } // Check that peek works. - peekBuf := make([]byte, 10) - n, err := c.EP.Peek([][]byte{peekBuf}) + var peekBuf bytes.Buffer + res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true}) if err != nil { t.Fatalf("Peek failed: %s", err) } - peekBuf = peekBuf[:n] - if !bytes.Equal(data, peekBuf) { - t.Fatalf("got data = %v, want = %v", peekBuf, data) + if got, want := res.Count, len(data); got != want { + t.Fatalf("res.Count = %d, want %d", got, want) } - - // Receive data. - v, _, err := c.EP.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) + if !bytes.Equal(data, peekBuf.Bytes()) { + t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data) } + // Receive data. + v := ept.CheckRead(t, defaultMTU) if !bytes.Equal(data, v) { t.Fatalf("got data = %v, want = %v", v, data) } // Now that we drained the queue, check that functions fail with the // right error code. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive) - } - - if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive { - t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive) + ept.CheckReadError(t, tcpip.ErrClosedForReceive) + var buf bytes.Buffer + if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive { + t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive) } } @@ -4619,17 +4614,8 @@ func TestSelfConnect(t *testing.T) { // Read back what was written. wq.EventUnregister(&waitEntry) wq.EventRegister(&waitEntry, waiter.EventIn) - rd, _, err := ep.Read(nil) - if err != nil { - if err != tcpip.ErrWouldBlock { - t.Fatalf("Read failed: %s", err) - } - <-notifyCh - rd, _, err = ep.Read(nil) - if err != nil { - t.Fatalf("Read failed: %s", err) - } - } + ept := endpointTester{ep} + rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second) if !bytes.Equal(data, rd) { t.Fatalf("got data = %v, want = %v", rd, data) @@ -5082,9 +5068,8 @@ func TestKeepalive(t *testing.T) { } // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Send some data and wait before ACKing it. Keepalives should be disabled // during this period. @@ -5173,9 +5158,7 @@ func TestKeepalive(t *testing.T) { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got) } - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) - } + ept.CheckReadError(t, tcpip.ErrTimeout) if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 { t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got) @@ -6070,9 +6053,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) { t.Errorf("unexpected endpoint state: want %s, got %s", want, got) } - if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected { - t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected) - } + ept := endpointTester{ep} + ept.CheckReadError(t, tcpip.ErrNotConnected) if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 { t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1) } @@ -6227,7 +6209,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) { // Now read all the data from the endpoint and verify that advertised // window increases to the full available buffer size. for { - _, _, err := c.EP.Read(nil) + _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } @@ -6351,11 +6333,11 @@ func TestReceiveBufferAutoTuning(t *testing.T) { // to happen before we measure the new window. totalCopied := 0 for { - b, _, err := c.EP.Read(nil) + res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err == tcpip.ErrWouldBlock { break } - totalCopied += len(b) + totalCopied += res.Count } // Invoke the moderation API. This is required for auto-tuning @@ -7272,9 +7254,8 @@ func TestTCPUserTimeout(t *testing.T) { ), ) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrTimeout) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) @@ -7317,9 +7298,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { } // Check that the connection is still alive. - if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock) - } + ept := endpointTester{c.EP} + ept.CheckReadError(t, tcpip.ErrWouldBlock) // Now receive 1 keepalives, but don't ACK it. b := c.GetPacket() @@ -7358,9 +7338,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) { ), ) - if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout { - t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout) - } + ept.CheckReadError(t, tcpip.ErrTimeout) if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want { t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want) } @@ -7417,11 +7395,11 @@ func TestIncreaseWindowOnRead(t *testing.T) { // defaultMTU is a good enough estimate for the MSS used for this // connection. for read < defaultMTU*2 { - v, _, err := c.EP.Read(nil) + res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Read failed: %s", err) } - read += len(v) + read += res.Count } // After reading > MSS worth of data, we surely crossed MSS. See the ack: diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go index 0f9ed06cd..9e02d467d 100644 --- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go @@ -20,6 +20,7 @@ import ( "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -105,11 +106,18 @@ func TestTimeStampEnabledConnect(t *testing.T) { // There should be 5 views to read and each of them should // contain the same data. for i := 0; i < 5; i++ { - got, _, err := c.EP.Read(nil) + var buf bytes.Buffer + result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } - if want := data; bytes.Compare(got, want) != 0 { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) + } + if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { t.Fatalf("Data is different: got: %v, want: %v", got, want) } } @@ -286,11 +294,18 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) { } // Issue a read and we should data. - got, _, err := c.EP.Read(nil) + var buf bytes.Buffer + result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{}) if err != nil { t.Fatalf("Unexpected error from Read: %v", err) } - if want := data; bytes.Compare(got, want) != 0 { + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" { + t.Errorf("Read: unexpected result (-want +got):\n%s", diff) + } + if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 { t.Fatalf("Data is different: got: %v, want: %v", got, want) } } diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD index 7ebae63d8..153e8c950 100644 --- a/pkg/tcpip/transport/udp/BUILD +++ b/pkg/tcpip/transport/udp/BUILD @@ -58,5 +58,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "//pkg/waiter", + "@com_github_google_go_cmp//cmp:go_default_library", ], ) diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4e8bd8b04..075de1db0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -16,6 +16,7 @@ package udp import ( "fmt" + "io" "sync/atomic" "gvisor.dev/gvisor/pkg/sync" @@ -282,11 +283,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// Read reads data from the endpoint. This method does not block if -// there is no data pending. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { if err := e.LastError(); err != nil { - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } e.rcvMu.Lock() @@ -298,18 +298,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } p := e.rcvList.Front() - e.rcvList.Remove(p) - e.rcvBufSize -= p.data.Size() - e.rcvMu.Unlock() - - if addr != nil { - *addr = p.senderAddress + if !opts.Peek { + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() } + e.rcvMu.Unlock() + // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, Timestamp: p.timestamp, @@ -331,7 +330,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess cm.HasOriginalDstAddress = true cm.OriginalDstAddress = p.destinationAddress } - return p.data.ToView(), cm, nil + + // Read Result + res := tcpip.ReadResult{ + Total: p.data.Size(), + ControlMessages: cm, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = p.senderAddress + } + + n, err := p.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -566,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return int64(len(v)), nil, nil } -// Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. func (e *endpoint) OnReuseAddressSet(v bool) { e.mu.Lock() diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index 8429f34b4..455b8c2aa 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -18,10 +18,12 @@ import ( "bytes" "context" "fmt" + "io/ioutil" "math/rand" "testing" "time" + "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" @@ -595,13 +597,13 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe // Take a snapshot of the stats to validate them at the end of the test. epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone() - var addr tcpip.FullAddress - v, cm, err := c.ep.Read(&addr) + var buf bytes.Buffer + res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true}) if err == tcpip.ErrWouldBlock { // Wait for data to become available. select { case <-ch: - v, cm, err = c.ep.Read(&addr) + res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true}) case <-time.After(300 * time.Millisecond): if packetShouldBeDropped { @@ -621,23 +623,32 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe } if packetShouldBeDropped { - c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr) + c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr) } - // Check the peer address. + // Check the read result. h := flow.header4Tuple(incoming) - if addr.Addr != h.srcAddr.Addr { - c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr) + if diff := cmp.Diff(tcpip.ReadResult{ + Count: buf.Len(), + Total: buf.Len(), + RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr}, + }, res, checker.IgnoreCmpPath( + "ControlMessages", // ControlMessages will be checked later. + "RemoteAddr.NIC", + "RemoteAddr.Port", + )); diff != "" { + c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff) } // Check the payload. + v := buf.Bytes() if !bytes.Equal(payload, v) { c.t.Fatalf("got payload = %x, want = %x", v, payload) } // Run any checkers against the ControlMessages. for _, f := range checkers { - f(c.t, cm) + f(c.t, res.ControlMessages) } c.checkEndpointReadStats(1, epstats, err) @@ -828,8 +839,8 @@ func TestV4ReadSelfSource(t *testing.T) { t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource) } - if _, _, err := c.ep.Read(nil); err != tt.wantErr { - t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr) + if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr { + t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr) } }) } diff --git a/test/syscalls/linux/socket_bind_to_device_distribution.cc b/test/syscalls/linux/socket_bind_to_device_distribution.cc index 06419772f..f8a0a80f2 100644 --- a/test/syscalls/linux/socket_bind_to_device_distribution.cc +++ b/test/syscalls/linux/socket_bind_to_device_distribution.cc @@ -204,7 +204,7 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { }); } - for (int i = 0; i < kConnectAttempts; i++) { + for (int32_t i = 0; i < kConnectAttempts; i++) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); ASSERT_THAT( @@ -212,22 +212,8 @@ TEST_P(BindToDeviceDistributionTest, Tcp) { connector.addr_len), SyscallSucceeds()); - // Do two separate sends to ensure two segments are received. This is - // required for netstack where read is incorrectly assuming a whole - // segment is read when endpoint.Read() is called which is technically - // incorrect as the syscall that invoked endpoint.Read() may only - // consume it partially. This results in a case where a close() of - // such a socket does not trigger a RST in netstack due to the - // endpoint assuming that the endpoint has no unread data. EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), SyscallSucceedsWithValue(sizeof(i))); - - // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly - // generates a RST. - if (IsRunningOnGvisor()) { - EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - } } // Join threads to be sure that all connections have been counted. diff --git a/test/syscalls/linux/socket_generic.cc b/test/syscalls/linux/socket_generic.cc index a28ee2233..de0b8bb11 100644 --- a/test/syscalls/linux/socket_generic.cc +++ b/test/syscalls/linux/socket_generic.cc @@ -43,6 +43,15 @@ TEST_P(AllSocketPairTest, BasicReadWrite) { EXPECT_EQ(data, absl::string_view(buf, 3)); } +TEST_P(AllSocketPairTest, BasicReadWriteBadBuffer) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + const std::string data = "abc"; + ASSERT_THAT(WriteFd(sockets->first_fd(), data.c_str(), 3), + SyscallSucceedsWithValue(3)); + ASSERT_THAT(ReadFd(sockets->second_fd(), nullptr, 3), + SyscallFailsWithErrno(EFAULT)); +} + TEST_P(AllSocketPairTest, BasicSendRecv) { auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); char sent_data[512]; diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 51b77ad85..a11147085 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -1507,7 +1507,7 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { } ScopedThread connecting_thread([&connector, &conn_addr]() { - for (int i = 0; i < kConnectAttempts; i++) { + for (int32_t i = 0; i < kConnectAttempts; i++) { const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( Socket(connector.family(), SOCK_STREAM, IPPROTO_TCP)); ASSERT_THAT( @@ -1515,22 +1515,8 @@ TEST_P(SocketInetReusePortTest, TcpPortReuseMultiThread_NoRandomSave) { connector.addr_len), SyscallSucceeds()); - // Do two separate sends to ensure two segments are received. This is - // required for netstack where read is incorrectly assuming a whole - // segment is read when endpoint.Read() is called which is technically - // incorrect as the syscall that invoked endpoint.Read() may only - // consume it partially. This results in a case where a close() of - // such a socket does not trigger a RST in netstack due to the - // endpoint assuming that the endpoint has no unread data. EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), SyscallSucceedsWithValue(sizeof(i))); - - // TODO(gvisor.dev/issue/1449): Remove this block once netstack correctly - // generates a RST. - if (IsRunningOnGvisor()) { - EXPECT_THAT(RetryEINTR(send)(fd.get(), &i, sizeof(i), 0), - SyscallSucceedsWithValue(sizeof(i))); - } } }); |