summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip/transport
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip/transport')
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go38
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go46
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go34
-rw-r--r--pkg/tcpip/transport/tcp/BUILD1
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go228
-rw-r--r--pkg/tcpip/transport/tcp/segment.go16
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go13
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go272
-rw-r--r--pkg/tcpip/transport/tcp/tcp_timestamp_test.go23
-rw-r--r--pkg/tcpip/transport/udp/BUILD1
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
-rw-r--r--pkg/tcpip/transport/udp/udp_test.go31
12 files changed, 403 insertions, 343 deletions
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d1e4a7cb7..2eb4457df 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,8 @@
package icmp
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -151,9 +153,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -163,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = p.senderAddress
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -329,11 +344,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
return nil
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index e5e247342..3ab060751 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -26,6 +26,7 @@ package packet
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -160,8 +161,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.PacketEndpoint.ReadPacket.
-func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -173,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn
err = tcpip.ErrClosedForReceive
}
ep.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ if !opts.Peek {
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+ }
ep.rcvMu.Unlock()
- if addr != nil {
- *addr = packet.senderAddr
+ res := tcpip.ReadResult{
+ Total: packet.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: packet.timestampNS,
+ },
}
-
- if info != nil {
- *info = packet.packetInfo
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = packet.senderAddr
+ }
+ if opts.NeedLinkPacketInfo {
+ res.LinkPacketInfo = packet.packetInfo
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return ep.ReadPacket(addr, nil)
+ n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -203,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, tcpip.ErrInvalidOptionValue
}
-// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tpcip.ErrNotSupported.
func (*endpoint) Disconnect() *tcpip.Error {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7befcfc9b..dd260535f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -190,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -202,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
pkt := e.rcvList.Front()
- e.rcvList.Remove(pkt)
- e.rcvBufSize -= pkt.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = pkt.senderAddr
+ res := tcpip.ReadResult{
+ Total: pkt.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: pkt.timestampNS,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = pkt.senderAddr
}
- return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
+ n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -363,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
return int64(len(payloadBytes)), nil, nil
}
-// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() *tcpip.Error {
return tcpip.ErrNotSupported
diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD
index cf232b508..7e81203ba 100644
--- a/pkg/tcpip/transport/tcp/BUILD
+++ b/pkg/tcpip/transport/tcp/BUILD
@@ -112,6 +112,7 @@ go_test(
"//pkg/tcpip/transport/tcp/testing/context",
"//pkg/test/testutil",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6e3c8860e..8f3981075 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"encoding/binary"
"fmt"
+ "io"
"math"
"runtime"
"strings"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
@@ -393,15 +393,28 @@ type endpoint struct {
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
- // The following fields are used to manage the receive queue. The
- // protocol goroutine adds ready-for-delivery segments to rcvList,
- // which are returned by Read() calls to users.
+ // rcvReadMu synchronizes calls to Read.
//
- // Once the peer has closed its send side, rcvClosed is set to true
- // to indicate to users that no more data is coming.
+ // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // must be held during each read to ensure atomicity, so that multiple reads
+ // do not interleave.
+ //
+ // rcvReadMu should be held before holding mu.
+ rcvReadMu sync.Mutex `state:"nosave"`
+
+ // rcvListMu synchronizes access to rcvList.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
+ rcvListMu sync.Mutex `state:"nosave"`
+
+ // rcvList is the queue for ready-for-delivery segments.
+ //
+ // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
+ // and removing segments from list. A range of segment can be determined, then
+ // temporarily release mu and rcvListMu while processing the segment range.
+ // This allows new segments to be appended to the list while processing.
+ //
+ // rcvListMu must be held to append segments to list.
rcvList segmentList `state:"wait"`
rcvClosed bool
// rcvBufSize is the total size of the receive buffer.
@@ -1309,8 +1322,69 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
e.UnlockUser()
}
-// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ e.rcvReadMu.Lock()
+ defer e.rcvReadMu.Unlock()
+
+ // N.B. Here we get a range of segments to be processed. It is safe to not
+ // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // can remove segments from the list through commitRead().
+ first, last, serr := e.startRead()
+ if serr != nil {
+ if serr == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
+ return tcpip.ReadResult{}, serr
+ }
+
+ var err error
+ done := 0
+ s := first
+ for s != nil && done < count {
+ var n int
+ n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ // Book keeping first then error handling.
+
+ done += n
+
+ if opts.Peek {
+ // For peek, we use the (first, last) range of segment returned from
+ // startRead. We don't consume the receive buffer, so commitRead should
+ // not be called.
+ //
+ // N.B. It is important to use `last` to determine the last segment, since
+ // appending can happen while we process, and will lead to data race.
+ if s == last {
+ break
+ }
+ s = s.Next()
+ } else {
+ // N.B. commitRead() conveniently returns the next segment to read, after
+ // removing the data/segment that is read.
+ s = e.commitRead(n)
+ }
+
+ if err != nil {
+ break
+ }
+ }
+
+ // If something is read, we must report it. Report error when nothing is read.
+ if done == 0 && err != nil {
+ return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ }
+ return tcpip.ReadResult{
+ Count: done,
+ Total: done,
+ }, nil
+}
+
+// startRead checks that endpoint is in a readable state, and return the
+// inclusive range of segments that can be read.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1319,7 +1393,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1327,61 +1401,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
- e.rcvListMu.Unlock()
if s == StateError {
if err := e.hardErrorLocked(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return nil, nil, err
}
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
+ return nil, nil, tcpip.ErrNotConnected
}
- v, err := e.readLocked()
- e.rcvListMu.Unlock()
-
- if err == tcpip.ErrClosedForReceive {
- e.stats.ReadErrors.ReadClosed.Increment()
- }
- return v, tcpip.ControlMessages{}, err
-}
-
-func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return buffer.View{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
- return buffer.View{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
- s := e.rcvList.Front()
- views := s.data.Views()
- v := views[s.viewToDeliver]
- s.viewToDeliver++
+ return e.rcvList.Front(), e.rcvList.Back(), nil
+}
+
+// commitRead commits a read of done bytes and returns the next non-empty
+// segment to read. Data read from the segment must have also been removed from
+// the segment in order for this method to work correctly.
+//
+// It is performance critical to call commitRead frequently when servicing a big
+// Read request, so TCP can make progress timely. Right now, it is designed to
+// do this per segment read, hence this method conveniently returns the next
+// segment to read while holding the lock.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) commitRead(done int) *segment {
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
- var delta int
- if s.viewToDeliver >= len(views) {
+ memDelta := 0
+ s := e.rcvList.Front()
+ for s != nil && s.data.Size() == 0 {
e.rcvList.Remove(s)
- // We only free up receive buffer space when the segment is released as the
- // segment is still holding on to the views even though some views have been
- // read out to the user.
- delta = s.segMemSize()
+ // Memory is only considered released when the whole segment has been
+ // read.
+ memDelta += s.segMemSize()
s.decRef()
+ s = e.rcvList.Front()
}
+ e.rcvBufUsed -= done
- e.rcvBufUsed -= len(v)
- // If the window was small before this read and if the read freed up
- // enough buffer space, to either fit an aMSS or half a receive buffer
- // (whichever smaller), then notify the protocol goroutine to send a
- // window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ if memDelta > 0 {
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
}
- return v, nil
+ return e.rcvList.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
@@ -1499,64 +1581,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return queueAndSend()
}
-// Peek reads data without consuming it from the endpoint.
-//
-// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) {
- e.LockUser()
- defer e.UnlockUser()
-
- // The endpoint can be read if it's connected, or if it's already closed
- // but has some pending unread data.
- if s := e.EndpointState(); !s.connected() && s != StateClose {
- if s == StateError {
- return 0, e.hardErrorLocked()
- }
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ErrInvalidEndpointState
- }
-
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
-
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
- e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ErrClosedForReceive
- }
- return 0, tcpip.ErrWouldBlock
- }
-
- // Make a copy of vec so we can modify the slide headers.
- vec = append([][]byte(nil), vec...)
-
- var num int64
- for s := e.rcvList.Front(); s != nil; s = s.Next() {
- views := s.data.Views()
-
- for i := s.viewToDeliver; i < len(views); i++ {
- v := views[i]
-
- for len(v) > 0 {
- if len(vec) == 0 {
- return num, nil
- }
- if len(vec[0]) == 0 {
- vec = vec[1:]
- continue
- }
-
- n := copy(vec[0], v)
- v = v[n:]
- vec[0] = vec[0][n:]
- num += int64(n)
- }
- }
- }
-
- return num, nil
-}
-
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
// Precondition: e.mu and e.rcvListMu must be held.
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 5ef73ec74..c5a6d2fba 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -37,7 +37,7 @@ const (
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
-// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+// segment is mostly immutable, the only field allowed to change is data.
//
// +stateify savable
type segment struct {
@@ -60,10 +60,7 @@ type segment struct {
hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
- // viewToDeliver keeps track of the next View that should be
- // delivered by the Read endpoint.
- viewToDeliver int
+ views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
flags uint8
@@ -84,6 +81,9 @@ type segment struct {
// acked indicates if the segment has already been SACKed.
acked bool
+
+ // dataMemSize is the memory used by data initially.
+ dataMemSize int
}
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
@@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
+ s.dataMemSize = s.data.Size()
return s
}
@@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
}
+ s.dataMemSize = s.data.Size()
return s
}
@@ -127,12 +129,12 @@ func (s *segment) clone() *segment {
netProto: s.netProto,
nicID: s.nicID,
remoteLinkAddr: s.remoteLinkAddr,
- viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
ep: s.ep,
qFlags: s.qFlags,
+ dataMemSize: s.dataMemSize,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -204,7 +206,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return SegSize + s.data.Size()
+ return SegSize + s.dataMemSize
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7dc2741a6..7422d8c02 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -24,16 +24,11 @@ import (
func (s *segment) saveData() buffer.VectorisedView {
// We cannot save s.data directly as s.data.views may alias to s.views,
// which is not allowed by state framework (in-struct pointer).
- v := make([]buffer.View, len(s.data.Views()))
- // For views already delivered, we cannot save them directly as they may
- // have already been sliced and saved elsewhere (e.g., readViews).
- for i := 0; i < s.viewToDeliver; i++ {
- v[i] = append([]byte(nil), s.data.Views()[i]...)
+ vs := make([]buffer.View, len(s.data.Views()))
+ for i, v := range s.data.Views() {
+ vs[i] = v
}
- for i := s.viewToDeliver; i < len(v); i++ {
- v[i] = s.data.Views()[i]
- }
- return buffer.NewVectorisedView(s.data.Size(), v)
+ return buffer.NewVectorisedView(s.data.Size(), vs)
}
// loadData is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index cf60d5b53..9fa4672d7 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -17,10 +17,12 @@ package tcp_test
import (
"bytes"
"fmt"
+ "io/ioutil"
"math"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/rand"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -40,6 +42,64 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
+// endpointTester provides helper functions to test a tcpip.Endpoint.
+type endpointTester struct {
+ ep tcpip.Endpoint
+}
+
+// CheckReadError issues a read to the endpoint and checking for an error.
+func (e *endpointTester) CheckReadError(t *testing.T, want *tcpip.Error) {
+ t.Helper()
+ res, got := e.ep.Read(ioutil.Discard, 1, tcpip.ReadOptions{})
+ if got != want {
+ t.Fatalf("ep.Read = %s, want %s", got, want)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{}, res); diff != "" {
+ t.Errorf("ep.Read: unexpected non-zero result (-want +got):\n%s", diff)
+ }
+}
+
+// CheckRead issues a read to the endpoint and checking for a success, returning
+// the data read.
+func (e *endpointTester) CheckRead(t *testing.T, count int) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ res, err := e.ep.Read(&buf, count, tcpip.ReadOptions{})
+ if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, res, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("ep.Read: unexpected result (-want +got):\n%s", diff)
+ }
+ return buf.Bytes()
+}
+
+// CheckReadFull reads from the endpoint for exactly count bytes.
+func (e *endpointTester) CheckReadFull(t *testing.T, count int, notifyRead <-chan struct{}, timeout time.Duration) []byte {
+ t.Helper()
+ var buf bytes.Buffer
+ var done int
+ for done < count {
+ res, err := e.ep.Read(&buf, count-done, tcpip.ReadOptions{})
+ if err == tcpip.ErrWouldBlock {
+ // Wait for receive to be notified.
+ select {
+ case <-notifyRead:
+ case <-time.After(timeout):
+ t.Fatalf("Timed out waiting for data to arrive")
+ }
+ continue
+ } else if err != nil {
+ t.Fatalf("ep.Read = _, %s; want _, nil", err)
+ }
+ done += res.Count
+ }
+ return buf.Bytes()
+}
+
const (
// defaultMTU is the MTU, in bytes, used throughout the tests, except
// where another value is explicitly used. It is chosen to match the MTU
@@ -740,9 +800,7 @@ func TestSimpleReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -762,11 +820,7 @@ func TestSimpleReceive(t *testing.T) {
}
// Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1492,14 +1546,11 @@ func TestSynSent(t *testing.T) {
t.Fatal("timed out waiting for packet to arrive")
}
+ ept := endpointTester{c.EP}
if test.reset {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionRefused {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionRefused)
- }
+ ept.CheckReadError(t, tcpip.ErrConnectionRefused)
} else {
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrAborted {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrAborted)
- }
+ ept.CheckReadError(t, tcpip.ErrAborted)
}
if got := c.Stack().Stats().TCP.CurrentConnected.Value(); got != 0 {
@@ -1524,9 +1575,8 @@ func TestOutOfOrderReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send second half of data first, with seqnum 3 ahead of expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1551,9 +1601,7 @@ func TestOutOfOrderReceive(t *testing.T) {
// Wait 200ms and check that no data has been received.
time.Sleep(200 * time.Millisecond)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send the first 3 bytes now.
c.SendPacket(data[:3], &context.Headers{
@@ -1566,24 +1614,7 @@ func TestOutOfOrderReceive(t *testing.T) {
})
// Receive data.
- read := make([]byte, 0, 6)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- if err == tcpip.ErrWouldBlock {
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
- continue
- }
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
+ read := ept.CheckReadFull(t, 6, ch, 5*time.Second)
// Check that we received the data in proper order.
if !bytes.Equal(data, read) {
@@ -1608,9 +1639,8 @@ func TestOutOfOrderFlood(t *testing.T) {
rcvBufSz := math.MaxUint16
c.CreateConnected(789, 30000, rcvBufSz)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send 100 packets before the actual one that is expected.
data := []byte{1, 2, 3, 4, 5, 6}
@@ -1685,9 +1715,8 @@ func TestRstOnCloseWithUnreadData(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1754,9 +1783,8 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
data := []byte{1, 2, 3}
c.SendPacket(data, &context.Headers{
@@ -1837,17 +1865,14 @@ func TestShutdownRead(t *testing.T) {
c.CreateConnected(789, 30000, -1 /* epRcvBuf */)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
if err := c.EP.Shutdown(tcpip.ShutdownRead); err != nil {
t.Fatalf("Shutdown failed: %s", err)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
var want uint64 = 1
if got := c.EP.Stats().(*tcp.Stats).ReadErrors.ReadClosed.Value(); got != want {
t.Fatalf("got EP stats Stats.ReadErrors.ReadClosed got %d want %d", got, want)
@@ -1865,10 +1890,8 @@ func TestFullWindowReceive(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- _, _, err := c.EP.Read(nil)
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Fill up the window w/ tcp.SegOverheadFactor*rcvBufSz as netstack multiplies
// the provided buffer value by tcp.SegOverheadFactor to calculate the actual
@@ -1905,11 +1928,7 @@ func TestFullWindowReceive(t *testing.T) {
)
// Receive data and check it.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
@@ -1991,8 +2010,9 @@ func TestSmallSegReceiveWindowAdvertisement(t *testing.T) {
// Read the data so that the subsequent ACK from the endpoint
// grows the right edge of the window.
- if _, _, err := c.EP.Read(nil); err != nil {
- t.Fatalf("got Read(nil) = %s", err)
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, math.MaxUint16, tcpip.ReadOptions{}); err != nil {
+ t.Fatalf("c.EP.Read: %s", err)
}
// Check if we have received max uint16 as our advertised
@@ -2027,9 +2047,9 @@ func TestNoWindowShrinking(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
+
// Send a 1 byte payload so that we can record the current receive window.
// Send a payload of half the size of rcvBufSize.
seqNum := iss.Add(1)
@@ -2051,11 +2071,7 @@ func TestNoWindowShrinking(t *testing.T) {
}
// Read the 1 byte payload we just sent.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- if got, want := payload, v; !bytes.Equal(got, want) {
+ if got, want := payload, ept.CheckRead(t, 1); !bytes.Equal(got, want) {
t.Fatalf("got data: %v, want: %v", got, want)
}
@@ -2128,24 +2144,8 @@ func TestNoWindowShrinking(t *testing.T) {
),
)
- // Wait for receive to be notified.
- select {
- case <-ch:
- case <-time.After(5 * time.Second):
- t.Fatalf("Timed out waiting for data to arrive")
- }
-
// Receive data and check it.
- read := make([]byte, 0, rcvBufSize)
- for len(read) < len(data) {
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
-
- read = append(read, v...)
- }
-
+ read := ept.CheckReadFull(t, len(data), ch, 5*time.Second)
if !bytes.Equal(data, read) {
t.Fatalf("got data = %v, want = %v", read, data)
}
@@ -2569,11 +2569,11 @@ func TestZeroScaledWindowReceive(t *testing.T) {
// we need to read at 3 packets.
sz := 0
for sz < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- sz += len(v)
+ sz += res.Count
}
checker.IPv4(t, c.GetPacket(),
@@ -3268,13 +3268,13 @@ func TestReceiveOnResetConnection(t *testing.T) {
loop:
for {
- switch _, _, err := c.EP.Read(nil); err {
+ switch _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err {
case tcpip.ErrWouldBlock:
select {
case <-ch:
// Expect the state to be StateError and subsequent Reads to fail with HardError.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrConnectionReset {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrConnectionReset)
+ if _, err := c.EP.Read(ioutil.Discard, math.MaxUint16, tcpip.ReadOptions{}); err != tcpip.ErrConnectionReset {
+ t.Fatalf("got c.EP.Read() = %s, want = %s", err, tcpip.ErrConnectionReset)
}
break loop
case <-time.After(1 * time.Second):
@@ -4164,9 +4164,8 @@ func TestReadAfterClosedState(t *testing.T) {
c.WQ.EventRegister(&we, waiter.EventIn)
defer c.WQ.EventUnregister(&we)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Shutdown immediately for write, check that we get a FIN.
if err := c.EP.Shutdown(tcpip.ShutdownWrite); err != nil {
@@ -4224,35 +4223,31 @@ func TestReadAfterClosedState(t *testing.T) {
}
// Check that peek works.
- peekBuf := make([]byte, 10)
- n, err := c.EP.Peek([][]byte{peekBuf})
+ var peekBuf bytes.Buffer
+ res, err := c.EP.Read(&peekBuf, 10, tcpip.ReadOptions{Peek: true})
if err != nil {
t.Fatalf("Peek failed: %s", err)
}
- peekBuf = peekBuf[:n]
- if !bytes.Equal(data, peekBuf) {
- t.Fatalf("got data = %v, want = %v", peekBuf, data)
+ if got, want := res.Count, len(data); got != want {
+ t.Fatalf("res.Count = %d, want %d", got, want)
}
-
- // Receive data.
- v, _, err := c.EP.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
+ if !bytes.Equal(data, peekBuf.Bytes()) {
+ t.Fatalf("got data = %v, want = %v", peekBuf.Bytes(), data)
}
+ // Receive data.
+ v := ept.CheckRead(t, defaultMTU)
if !bytes.Equal(data, v) {
t.Fatalf("got data = %v, want = %v", v, data)
}
// Now that we drained the queue, check that functions fail with the
// right error code.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrClosedForReceive)
- }
-
- if _, err := c.EP.Peek([][]byte{peekBuf}); err != tcpip.ErrClosedForReceive {
- t.Fatalf("got c.EP.Peek(...) = %s, want = %s", err, tcpip.ErrClosedForReceive)
+ ept.CheckReadError(t, tcpip.ErrClosedForReceive)
+ var buf bytes.Buffer
+ if _, err := c.EP.Read(&buf, 1, tcpip.ReadOptions{Peek: true}); err != tcpip.ErrClosedForReceive {
+ t.Fatalf("c.EP.Read(_, _, {Peek: true}) = %v, %s; want _, %s", res, err, tcpip.ErrClosedForReceive)
}
}
@@ -4619,17 +4614,8 @@ func TestSelfConnect(t *testing.T) {
// Read back what was written.
wq.EventUnregister(&waitEntry)
wq.EventRegister(&waitEntry, waiter.EventIn)
- rd, _, err := ep.Read(nil)
- if err != nil {
- if err != tcpip.ErrWouldBlock {
- t.Fatalf("Read failed: %s", err)
- }
- <-notifyCh
- rd, _, err = ep.Read(nil)
- if err != nil {
- t.Fatalf("Read failed: %s", err)
- }
- }
+ ept := endpointTester{ep}
+ rd := ept.CheckReadFull(t, len(data), notifyCh, 5*time.Second)
if !bytes.Equal(data, rd) {
t.Fatalf("got data = %v, want = %v", rd, data)
@@ -5082,9 +5068,8 @@ func TestKeepalive(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Send some data and wait before ACKing it. Keepalives should be disabled
// during this period.
@@ -5173,9 +5158,7 @@ func TestKeepalive(t *testing.T) {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout.Value() = %d, want = 1", got)
}
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got := c.Stack().Stats().TCP.CurrentEstablished.Value(); got != 0 {
t.Errorf("got stats.TCP.CurrentEstablished.Value() = %d, want = 0", got)
@@ -6070,9 +6053,8 @@ func TestEndpointBindListenAcceptState(t *testing.T) {
t.Errorf("unexpected endpoint state: want %s, got %s", want, got)
}
- if _, _, err := ep.Read(nil); err != tcpip.ErrNotConnected {
- t.Errorf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrNotConnected)
- }
+ ept := endpointTester{ep}
+ ept.CheckReadError(t, tcpip.ErrNotConnected)
if got := ep.Stats().(*tcp.Stats).ReadErrors.NotConnected.Value(); got != 1 {
t.Errorf("got EP stats Stats.ReadErrors.NotConnected got %d want %d", got, 1)
}
@@ -6227,7 +6209,7 @@ func TestReceiveBufferAutoTuningApplicationLimited(t *testing.T) {
// Now read all the data from the endpoint and verify that advertised
// window increases to the full available buffer size.
for {
- _, _, err := c.EP.Read(nil)
+ _, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
@@ -6351,11 +6333,11 @@ func TestReceiveBufferAutoTuning(t *testing.T) {
// to happen before we measure the new window.
totalCopied := 0
for {
- b, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err == tcpip.ErrWouldBlock {
break
}
- totalCopied += len(b)
+ totalCopied += res.Count
}
// Invoke the moderation API. This is required for auto-tuning
@@ -7272,9 +7254,8 @@ func TestTCPUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
@@ -7317,9 +7298,8 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
}
// Check that the connection is still alive.
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrWouldBlock {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrWouldBlock)
- }
+ ept := endpointTester{c.EP}
+ ept.CheckReadError(t, tcpip.ErrWouldBlock)
// Now receive 1 keepalives, but don't ACK it.
b := c.GetPacket()
@@ -7358,9 +7338,7 @@ func TestKeepaliveWithUserTimeout(t *testing.T) {
),
)
- if _, _, err := c.EP.Read(nil); err != tcpip.ErrTimeout {
- t.Fatalf("got c.EP.Read(nil) = %s, want = %s", err, tcpip.ErrTimeout)
- }
+ ept.CheckReadError(t, tcpip.ErrTimeout)
if got, want := c.Stack().Stats().TCP.EstablishedTimedout.Value(), origEstablishedTimedout+1; got != want {
t.Errorf("got c.Stack().Stats().TCP.EstablishedTimedout = %d, want = %d", got, want)
}
@@ -7417,11 +7395,11 @@ func TestIncreaseWindowOnRead(t *testing.T) {
// defaultMTU is a good enough estimate for the MSS used for this
// connection.
for read < defaultMTU*2 {
- v, _, err := c.EP.Read(nil)
+ res, err := c.EP.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Read failed: %s", err)
}
- read += len(v)
+ read += res.Count
}
// After reading > MSS worth of data, we surely crossed MSS. See the ack:
diff --git a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
index 0f9ed06cd..9e02d467d 100644
--- a/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_timestamp_test.go
@@ -20,6 +20,7 @@ import (
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -105,11 +106,18 @@ func TestTimeStampEnabledConnect(t *testing.T) {
// There should be 5 views to read and each of them should
// contain the same data.
for i := 0; i < 5; i++ {
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, len(data), tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
@@ -286,11 +294,18 @@ func TestSegmentNotDroppedWhenTimestampMissing(t *testing.T) {
}
// Issue a read and we should data.
- got, _, err := c.EP.Read(nil)
+ var buf bytes.Buffer
+ result, err := c.EP.Read(&buf, defaultMTU, tcpip.ReadOptions{})
if err != nil {
t.Fatalf("Unexpected error from Read: %v", err)
}
- if want := data; bytes.Compare(got, want) != 0 {
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ }, result, checker.IgnoreCmpPath("ControlMessages")); diff != "" {
+ t.Errorf("Read: unexpected result (-want +got):\n%s", diff)
+ }
+ if got, want := buf.Bytes(), data; bytes.Compare(got, want) != 0 {
t.Fatalf("Data is different: got: %v, want: %v", got, want)
}
}
diff --git a/pkg/tcpip/transport/udp/BUILD b/pkg/tcpip/transport/udp/BUILD
index 7ebae63d8..153e8c950 100644
--- a/pkg/tcpip/transport/udp/BUILD
+++ b/pkg/tcpip/transport/udp/BUILD
@@ -58,5 +58,6 @@ go_test(
"//pkg/tcpip/stack",
"//pkg/tcpip/transport/icmp",
"//pkg/waiter",
+ "@com_github_google_go_cmp//cmp:go_default_library",
],
)
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4e8bd8b04..075de1db0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,6 +16,7 @@ package udp
import (
"fmt"
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
@@ -282,11 +283,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
e.rcvMu.Lock()
@@ -298,18 +298,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
- e.rcvMu.Unlock()
-
- if addr != nil {
- *addr = p.senderAddress
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
}
+ e.rcvMu.Unlock()
+ // Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
Timestamp: p.timestamp,
@@ -331,7 +330,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
}
- return p.data.ToView(), cm, nil
+
+ // Read Result
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: cm,
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
+ }
+
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -566,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()
diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go
index 8429f34b4..455b8c2aa 100644
--- a/pkg/tcpip/transport/udp/udp_test.go
+++ b/pkg/tcpip/transport/udp/udp_test.go
@@ -18,10 +18,12 @@ import (
"bytes"
"context"
"fmt"
+ "io/ioutil"
"math/rand"
"testing"
"time"
+ "github.com/google/go-cmp/cmp"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/checker"
@@ -595,13 +597,13 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
// Take a snapshot of the stats to validate them at the end of the test.
epstats := c.ep.Stats().(*tcpip.TransportEndpointStats).Clone()
- var addr tcpip.FullAddress
- v, cm, err := c.ep.Read(&addr)
+ var buf bytes.Buffer
+ res, err := c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
if err == tcpip.ErrWouldBlock {
// Wait for data to become available.
select {
case <-ch:
- v, cm, err = c.ep.Read(&addr)
+ res, err = c.ep.Read(&buf, defaultMTU, tcpip.ReadOptions{NeedRemoteAddr: true})
case <-time.After(300 * time.Millisecond):
if packetShouldBeDropped {
@@ -621,23 +623,32 @@ func testReadInternal(c *testContext, flow testFlow, packetShouldBeDropped, expe
}
if packetShouldBeDropped {
- c.t.Fatalf("Read unexpectedly received data from %s", addr.Addr)
+ c.t.Fatalf("Read unexpectedly received data from %s", res.RemoteAddr.Addr)
}
- // Check the peer address.
+ // Check the read result.
h := flow.header4Tuple(incoming)
- if addr.Addr != h.srcAddr.Addr {
- c.t.Fatalf("got address = %s, want = %s", addr.Addr, h.srcAddr.Addr)
+ if diff := cmp.Diff(tcpip.ReadResult{
+ Count: buf.Len(),
+ Total: buf.Len(),
+ RemoteAddr: tcpip.FullAddress{Addr: h.srcAddr.Addr},
+ }, res, checker.IgnoreCmpPath(
+ "ControlMessages", // ControlMessages will be checked later.
+ "RemoteAddr.NIC",
+ "RemoteAddr.Port",
+ )); diff != "" {
+ c.t.Fatalf("Read: unexpected result (-want +got):\n%s", diff)
}
// Check the payload.
+ v := buf.Bytes()
if !bytes.Equal(payload, v) {
c.t.Fatalf("got payload = %x, want = %x", v, payload)
}
// Run any checkers against the ControlMessages.
for _, f := range checkers {
- f(c.t, cm)
+ f(c.t, res.ControlMessages)
}
c.checkEndpointReadStats(1, epstats, err)
@@ -828,8 +839,8 @@ func TestV4ReadSelfSource(t *testing.T) {
t.Errorf("c.s.Stats().IP.InvalidSourceAddressesReceived got %d, want %d", got, tt.wantInvalidSource)
}
- if _, _, err := c.ep.Read(nil); err != tt.wantErr {
- t.Errorf("got c.ep.Read(nil) = %s, want = %s", err, tt.wantErr)
+ if _, err := c.ep.Read(ioutil.Discard, defaultMTU, tcpip.ReadOptions{}); err != tt.wantErr {
+ t.Errorf("got c.ep.Read = %s, want = %s", err, tt.wantErr)
}
})
}