summaryrefslogtreecommitdiffhomepage
path: root/pkg/tcpip
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/tcpip')
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go57
-rw-r--r--pkg/tcpip/buffer/view.go37
-rw-r--r--pkg/tcpip/tcpip.go81
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go38
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go46
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go34
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go228
-rw-r--r--pkg/tcpip/transport/tcp/segment.go16
-rw-r--r--pkg/tcpip/transport/tcp/segment_state.go13
-rw-r--r--pkg/tcpip/transport/tcp/tcp_state_autogen.go58
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go43
11 files changed, 377 insertions, 274 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 4f551cd92..7193f56ad 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -286,45 +286,47 @@ type opErrorer interface {
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
-func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
+func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) {
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
default:
}
- read, _, err := ep.Read(addr)
+ w := tcpip.SliceWriter(b)
+ opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
+ res, err := ep.Read(&w, len(b), opts)
if err == tcpip.ErrWouldBlock {
- if dontWait {
- return nil, errWouldBlock
- }
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- read, _, err = ep.Read(addr)
+ res, err = ep.Read(&w, len(b), opts)
if err != tcpip.ErrWouldBlock {
break
}
select {
case <-deadline:
- return nil, errorer.newOpError("read", &timeoutError{})
+ return 0, errorer.newOpError("read", &timeoutError{})
case <-notifyCh:
}
}
}
if err == tcpip.ErrClosedForReceive {
- return nil, io.EOF
+ return 0, io.EOF
}
if err != nil {
- return nil, errorer.newOpError("read", errors.New(err.String()))
+ return 0, errorer.newOpError("read", errors.New(err.String()))
}
- return read, nil
+ if addr != nil {
+ *addr = res.RemoteAddr
+ }
+ return res.Count, nil
}
// Read implements net.Conn.Read.
@@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) {
deadline := c.readCancel()
- numRead := 0
- defer func() {
- if numRead != 0 {
- c.ep.ModerateRecvBuf(numRead)
- }
- }()
- for numRead != len(b) {
- if len(c.read) == 0 {
- var err error
- c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
- if err != nil {
- if numRead != 0 {
- return numRead, nil
- }
- return numRead, err
- }
- }
- n := copy(b[numRead:], c.read)
- c.read.TrimFront(n)
- numRead += n
- if len(c.read) == 0 {
- c.read = nil
- }
+ n, err := commonRead(b, c.ep, c.wq, deadline, nil, c)
+ if n != 0 {
+ c.ep.ModerateRecvBuf(n)
}
- return numRead, nil
+ return n, err
}
// Write implements net.Conn.Write.
@@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
var addr tcpip.FullAddress
- read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
+ n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c)
if err != nil {
return 0, nil, err
}
-
- return copy(b, read), fullToUDPAddr(addr), nil
+ return n, fullToUDPAddr(addr), nil
}
func (c *UDPConn) Write(b []byte) (int, error) {
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 8db70a700..5dd1b1b6b 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) {
}
// Read implements io.Reader.
-func (vv *VectorisedView) Read(v View) (copied int, err error) {
- count := len(v)
+func (vv *VectorisedView) Read(b []byte) (copied int, err error) {
+ count := len(b)
for count > 0 && len(vv.views) > 0 {
if count < len(vv.views[0]) {
vv.size -= count
- copy(v[copied:], vv.views[0][:count])
+ copy(b[copied:], vv.views[0][:count])
vv.views[0].TrimFront(count)
copied += count
return copied, nil
}
count -= len(vv.views[0])
- copy(v[copied:], vv.views[0])
+ copy(b[copied:], vv.views[0])
copied += len(vv.views[0])
vv.removeFirst()
}
@@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
return copied
}
+// ReadTo reads up to count bytes from vv to dst. It also removes them from vv
+// unless peek is true.
+func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) {
+ var err error
+ done := 0
+ for _, v := range vv.Views() {
+ remaining := count - done
+ if remaining <= 0 {
+ break
+ }
+ if len(v) > remaining {
+ v = v[:remaining]
+ }
+
+ var n int
+ n, err = dst.Write(v)
+ if n > 0 {
+ done += n
+ }
+ if err != nil {
+ break
+ }
+ }
+ if !peek {
+ vv.TrimFront(done)
+ }
+ return done, err
+}
+
// CapLength irreversibly reduces the length of the vectorised view.
func (vv *VectorisedView) CapLength(length int) {
if length < 0 {
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index ef0f51f1a..f798056c0 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -31,6 +31,7 @@ package tcpip
import (
"errors"
"fmt"
+ "io"
"math/bits"
"reflect"
"strconv"
@@ -39,7 +40,6 @@ import (
"time"
"gvisor.dev/gvisor/pkg/sync"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/waiter"
)
@@ -113,6 +113,7 @@ var (
ErrNotPermitted = &Error{msg: "operation not permitted"}
ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"}
ErrMalformedHeader = &Error{msg: "header is malformed"}
+ ErrBadBuffer = &Error{msg: "bad buffer"}
)
var messageToError map[string]*Error
@@ -162,6 +163,7 @@ func StringToError(s string) *Error {
ErrNotPermitted,
ErrAddressFamilyNotSupported,
ErrMalformedHeader,
+ ErrBadBuffer,
}
messageToError = make(map[string]*Error)
@@ -496,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) {
return s[:size], nil
}
+var _ io.Writer = (*SliceWriter)(nil)
+
+// SliceWriter implements io.Writer for slices.
+type SliceWriter []byte
+
+// Write implements io.Writer.Write.
+func (s *SliceWriter) Write(b []byte) (int, error) {
+ n := copy(*s, b)
+ *s = (*s)[n:]
+ if n < len(b) {
+ return n, io.ErrShortWrite
+ }
+ return n, nil
+}
+
// A ControlMessages contains socket control messages for IP sockets.
//
// +stateify savable
@@ -552,6 +569,40 @@ type PacketOwner interface {
GID() uint32
}
+// ReadOptions contains options for Endpoint.Read.
+type ReadOptions struct {
+ // Peek indicates whether this read is a peek.
+ Peek bool
+
+ // NeedRemoteAddr indicates whether to return the remote address, if
+ // supported.
+ NeedRemoteAddr bool
+
+ // NeedLinkPacketInfo indicates whether to return the link-layer information,
+ // if supported.
+ NeedLinkPacketInfo bool
+}
+
+// ReadResult represents result for a successful Endpoint.Read.
+type ReadResult struct {
+ // Count is the number of bytes received and written to the buffer.
+ Count int
+
+ // Total is the number of bytes of the received packet. This can be used to
+ // determine whether the read is truncated.
+ Total int
+
+ // ControlMessages is the control messages received.
+ ControlMessages ControlMessages
+
+ // RemoteAddr is the remote address if ReadOptions.NeedAddr is true.
+ RemoteAddr FullAddress
+
+ // LinkPacketInfo is the link-layer information of the received packet if
+ // ReadOptions.NeedLinkPacketInfo is true.
+ LinkPacketInfo LinkPacketInfo
+}
+
// Endpoint is the interface implemented by transport protocols (e.g., tcp, udp)
// that exposes functionality like read, write, connect, etc. to users of the
// networking stack.
@@ -566,11 +617,15 @@ type Endpoint interface {
// Abort is best effort; implementing Abort with Close is acceptable.
Abort()
- // Read reads data from the endpoint and optionally returns the sender.
+ // Read reads data from the endpoint and optionally writes to dst.
+ //
+ // This method does not block if there is no data pending; in this case,
+ // ErrWouldBlock is returned.
//
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- Read(*FullAddress) (buffer.View, ControlMessages, *Error)
+ // If non-zero number of bytes are successfully read and written to dst, err
+ // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
+ // should be returned.
+ Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
@@ -592,11 +647,6 @@ type Endpoint interface {
// not). The channel is only non-nil in this case.
Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error)
- // Peek reads data without consuming it from the endpoint.
- //
- // This method does not block if there is no data pending.
- Peek([][]byte) (int64, *Error)
-
// Connect connects the endpoint to its peer. Specifying a NIC is
// optional.
//
@@ -703,17 +753,6 @@ type LinkPacketInfo struct {
PktType PacketType
}
-// PacketEndpoint are additional methods that are only implemented by Packet
-// endpoints.
-type PacketEndpoint interface {
- // ReadPacket reads a datagram/packet from the endpoint and optionally
- // returns the sender and additional LinkPacketInfo.
- //
- // This method does not block if there is no data pending. It will also
- // either return an error or data, never both.
- ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error)
-}
-
// EndpointInfo is the interface implemented by each endpoint info struct.
type EndpointInfo interface {
// IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index d1e4a7cb7..2eb4457df 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -15,6 +15,8 @@
package icmp
import (
+ "io"
+
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -151,9 +153,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -163,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = p.senderAddress
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: p.timestamp,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
}
- return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -329,11 +344,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// SetSockOpt sets a socket option.
func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
return nil
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index e5e247342..3ab060751 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -26,6 +26,7 @@ package packet
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -160,8 +161,8 @@ func (ep *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (ep *endpoint) ModerateRecvBuf(copied int) {}
-// Read implements tcpip.PacketEndpoint.ReadPacket.
-func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -173,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn
err = tcpip.ErrClosedForReceive
}
ep.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
packet := ep.rcvList.Front()
- ep.rcvList.Remove(packet)
- ep.rcvBufSize -= packet.data.Size()
+ if !opts.Peek {
+ ep.rcvList.Remove(packet)
+ ep.rcvBufSize -= packet.data.Size()
+ }
ep.rcvMu.Unlock()
- if addr != nil {
- *addr = packet.senderAddr
+ res := tcpip.ReadResult{
+ Total: packet.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: packet.timestampNS,
+ },
}
-
- if info != nil {
- *info = packet.packetInfo
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = packet.senderAddr
+ }
+ if opts.NeedLinkPacketInfo {
+ res.LinkPacketInfo = packet.packetInfo
}
- return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil
-}
-
-// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- return ep.ReadPacket(addr, nil)
+ n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
@@ -203,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha
return 0, nil, tcpip.ErrInvalidOptionValue
}
-// Peek implements tcpip.Endpoint.Peek.
-func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be
// disconnected, and this function always returns tpcip.ErrNotSupported.
func (*endpoint) Disconnect() *tcpip.Error {
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 7befcfc9b..dd260535f 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -27,6 +27,7 @@ package raw
import (
"fmt"
+ "io"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -190,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -202,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
pkt := e.rcvList.Front()
- e.rcvList.Remove(pkt)
- e.rcvBufSize -= pkt.data.Size()
+ if !opts.Peek {
+ e.rcvList.Remove(pkt)
+ e.rcvBufSize -= pkt.data.Size()
+ }
e.rcvMu.Unlock()
- if addr != nil {
- *addr = pkt.senderAddr
+ res := tcpip.ReadResult{
+ Total: pkt.data.Size(),
+ ControlMessages: tcpip.ControlMessages{
+ HasTimestamp: true,
+ Timestamp: pkt.timestampNS,
+ },
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = pkt.senderAddr
}
- return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil
+ n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// Write implements tcpip.Endpoint.Write.
@@ -363,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64,
return int64(len(payloadBytes)), nil, nil
}
-// Peek implements tcpip.Endpoint.Peek.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// Disconnect implements tcpip.Endpoint.Disconnect.
func (*endpoint) Disconnect() *tcpip.Error {
return tcpip.ErrNotSupported
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index 6e3c8860e..8f3981075 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -17,6 +17,7 @@ package tcp
import (
"encoding/binary"
"fmt"
+ "io"
"math"
"runtime"
"strings"
@@ -27,7 +28,6 @@ import (
"gvisor.dev/gvisor/pkg/sleep"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/buffer"
"gvisor.dev/gvisor/pkg/tcpip/hash/jenkins"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/ports"
@@ -393,15 +393,28 @@ type endpoint struct {
lastErrorMu sync.Mutex `state:"nosave"`
lastError *tcpip.Error `state:".(string)"`
- // The following fields are used to manage the receive queue. The
- // protocol goroutine adds ready-for-delivery segments to rcvList,
- // which are returned by Read() calls to users.
+ // rcvReadMu synchronizes calls to Read.
//
- // Once the peer has closed its send side, rcvClosed is set to true
- // to indicate to users that no more data is coming.
+ // mu and rcvListMu are temporarily released during data copying. rcvReadMu
+ // must be held during each read to ensure atomicity, so that multiple reads
+ // do not interleave.
+ //
+ // rcvReadMu should be held before holding mu.
+ rcvReadMu sync.Mutex `state:"nosave"`
+
+ // rcvListMu synchronizes access to rcvList.
//
// rcvListMu can be taken after the endpoint mu below.
- rcvListMu sync.Mutex `state:"nosave"`
+ rcvListMu sync.Mutex `state:"nosave"`
+
+ // rcvList is the queue for ready-for-delivery segments.
+ //
+ // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data
+ // and removing segments from list. A range of segment can be determined, then
+ // temporarily release mu and rcvListMu while processing the segment range.
+ // This allows new segments to be appended to the list while processing.
+ //
+ // rcvListMu must be held to append segments to list.
rcvList segmentList `state:"wait"`
rcvClosed bool
// rcvBufSize is the total size of the receive buffer.
@@ -1309,8 +1322,69 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
e.UnlockUser()
}
-// Read reads data from the endpoint.
-func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+ e.rcvReadMu.Lock()
+ defer e.rcvReadMu.Unlock()
+
+ // N.B. Here we get a range of segments to be processed. It is safe to not
+ // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we
+ // can remove segments from the list through commitRead().
+ first, last, serr := e.startRead()
+ if serr != nil {
+ if serr == tcpip.ErrClosedForReceive {
+ e.stats.ReadErrors.ReadClosed.Increment()
+ }
+ return tcpip.ReadResult{}, serr
+ }
+
+ var err error
+ done := 0
+ s := first
+ for s != nil && done < count {
+ var n int
+ n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ // Book keeping first then error handling.
+
+ done += n
+
+ if opts.Peek {
+ // For peek, we use the (first, last) range of segment returned from
+ // startRead. We don't consume the receive buffer, so commitRead should
+ // not be called.
+ //
+ // N.B. It is important to use `last` to determine the last segment, since
+ // appending can happen while we process, and will lead to data race.
+ if s == last {
+ break
+ }
+ s = s.Next()
+ } else {
+ // N.B. commitRead() conveniently returns the next segment to read, after
+ // removing the data/segment that is read.
+ s = e.commitRead(n)
+ }
+
+ if err != nil {
+ break
+ }
+ }
+
+ // If something is read, we must report it. Report error when nothing is read.
+ if done == 0 && err != nil {
+ return tcpip.ReadResult{}, tcpip.ErrBadBuffer
+ }
+ return tcpip.ReadResult{
+ Count: done,
+ Total: done,
+ }, nil
+}
+
+// startRead checks that endpoint is in a readable state, and return the
+// inclusive range of segments that can be read.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) {
e.LockUser()
defer e.UnlockUser()
@@ -1319,7 +1393,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// on a receive. It can expect to read any data after the handshake
// is complete. RFC793, section 3.9, p58.
if e.EndpointState() == StateSynSent {
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
// The endpoint can be read if it's connected, or if it's already closed
@@ -1327,61 +1401,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages,
// would cause the state to become StateError so we should allow the
// reads to proceed before returning a ECONNRESET.
e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
+
bufUsed := e.rcvBufUsed
if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 {
- e.rcvListMu.Unlock()
if s == StateError {
if err := e.hardErrorLocked(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return nil, nil, err
}
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
e.stats.ReadErrors.NotConnected.Increment()
- return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected
+ return nil, nil, tcpip.ErrNotConnected
}
- v, err := e.readLocked()
- e.rcvListMu.Unlock()
-
- if err == tcpip.ErrClosedForReceive {
- e.stats.ReadErrors.ReadClosed.Increment()
- }
- return v, tcpip.ControlMessages{}, err
-}
-
-func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
if e.rcvBufUsed == 0 {
if e.rcvClosed || !e.EndpointState().connected() {
- return buffer.View{}, tcpip.ErrClosedForReceive
+ return nil, nil, tcpip.ErrClosedForReceive
}
- return buffer.View{}, tcpip.ErrWouldBlock
+ return nil, nil, tcpip.ErrWouldBlock
}
- s := e.rcvList.Front()
- views := s.data.Views()
- v := views[s.viewToDeliver]
- s.viewToDeliver++
+ return e.rcvList.Front(), e.rcvList.Back(), nil
+}
+
+// commitRead commits a read of done bytes and returns the next non-empty
+// segment to read. Data read from the segment must have also been removed from
+// the segment in order for this method to work correctly.
+//
+// It is performance critical to call commitRead frequently when servicing a big
+// Read request, so TCP can make progress timely. Right now, it is designed to
+// do this per segment read, hence this method conveniently returns the next
+// segment to read while holding the lock.
+//
+// Precondition: e.rcvReadMu must be held.
+func (e *endpoint) commitRead(done int) *segment {
+ e.LockUser()
+ defer e.UnlockUser()
+ e.rcvListMu.Lock()
+ defer e.rcvListMu.Unlock()
- var delta int
- if s.viewToDeliver >= len(views) {
+ memDelta := 0
+ s := e.rcvList.Front()
+ for s != nil && s.data.Size() == 0 {
e.rcvList.Remove(s)
- // We only free up receive buffer space when the segment is released as the
- // segment is still holding on to the views even though some views have been
- // read out to the user.
- delta = s.segMemSize()
+ // Memory is only considered released when the whole segment has been
+ // read.
+ memDelta += s.segMemSize()
s.decRef()
+ s = e.rcvList.Front()
}
+ e.rcvBufUsed -= done
- e.rcvBufUsed -= len(v)
- // If the window was small before this read and if the read freed up
- // enough buffer space, to either fit an aMSS or half a receive buffer
- // (whichever smaller), then notify the protocol goroutine to send a
- // window update.
- if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above {
- e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ if memDelta > 0 {
+ // If the window was small before this read and if the read freed up
+ // enough buffer space, to either fit an aMSS or half a receive buffer
+ // (whichever smaller), then notify the protocol goroutine to send a
+ // window update.
+ if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above {
+ e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow)
+ }
}
- return v, nil
+ return e.rcvList.Front()
}
// isEndpointWritableLocked checks if a given endpoint is writable
@@ -1499,64 +1581,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return queueAndSend()
}
-// Peek reads data without consuming it from the endpoint.
-//
-// This method does not block if there is no data pending.
-func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) {
- e.LockUser()
- defer e.UnlockUser()
-
- // The endpoint can be read if it's connected, or if it's already closed
- // but has some pending unread data.
- if s := e.EndpointState(); !s.connected() && s != StateClose {
- if s == StateError {
- return 0, e.hardErrorLocked()
- }
- e.stats.ReadErrors.InvalidEndpointState.Increment()
- return 0, tcpip.ErrInvalidEndpointState
- }
-
- e.rcvListMu.Lock()
- defer e.rcvListMu.Unlock()
-
- if e.rcvBufUsed == 0 {
- if e.rcvClosed || !e.EndpointState().connected() {
- e.stats.ReadErrors.ReadClosed.Increment()
- return 0, tcpip.ErrClosedForReceive
- }
- return 0, tcpip.ErrWouldBlock
- }
-
- // Make a copy of vec so we can modify the slide headers.
- vec = append([][]byte(nil), vec...)
-
- var num int64
- for s := e.rcvList.Front(); s != nil; s = s.Next() {
- views := s.data.Views()
-
- for i := s.viewToDeliver; i < len(views); i++ {
- v := views[i]
-
- for len(v) > 0 {
- if len(vec) == 0 {
- return num, nil
- }
- if len(vec[0]) == 0 {
- vec = vec[1:]
- continue
- }
-
- n := copy(vec[0], v)
- v = v[n:]
- vec[0] = vec[0][n:]
- num += int64(n)
- }
- }
- }
-
- return num, nil
-}
-
// selectWindowLocked returns the new window without checking for shrinking or scaling
// applied.
// Precondition: e.mu and e.rcvListMu must be held.
diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go
index 5ef73ec74..c5a6d2fba 100644
--- a/pkg/tcpip/transport/tcp/segment.go
+++ b/pkg/tcpip/transport/tcp/segment.go
@@ -37,7 +37,7 @@ const (
// segment represents a TCP segment. It holds the payload and parsed TCP segment
// information, and can be added to intrusive lists.
-// segment is mostly immutable, the only field allowed to change is viewToDeliver.
+// segment is mostly immutable, the only field allowed to change is data.
//
// +stateify savable
type segment struct {
@@ -60,10 +60,7 @@ type segment struct {
hdr header.TCP
// views is used as buffer for data when its length is large
// enough to store a VectorisedView.
- views [8]buffer.View `state:"nosave"`
- // viewToDeliver keeps track of the next View that should be
- // delivered by the Read endpoint.
- viewToDeliver int
+ views [8]buffer.View `state:"nosave"`
sequenceNumber seqnum.Value
ackNumber seqnum.Value
flags uint8
@@ -84,6 +81,9 @@ type segment struct {
// acked indicates if the segment has already been SACKed.
acked bool
+
+ // dataMemSize is the memory used by data initially.
+ dataMemSize int
}
func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment {
@@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *
s.data = pkt.Data.Clone(s.views[:])
s.hdr = header.TCP(pkt.TransportHeader().View())
s.rcvdTime = time.Now()
+ s.dataMemSize = s.data.Size()
return s
}
@@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment {
s.views[0] = v
s.data = buffer.NewVectorisedView(len(v), s.views[:1])
}
+ s.dataMemSize = s.data.Size()
return s
}
@@ -127,12 +129,12 @@ func (s *segment) clone() *segment {
netProto: s.netProto,
nicID: s.nicID,
remoteLinkAddr: s.remoteLinkAddr,
- viewToDeliver: s.viewToDeliver,
rcvdTime: s.rcvdTime,
xmitTime: s.xmitTime,
xmitCount: s.xmitCount,
ep: s.ep,
qFlags: s.qFlags,
+ dataMemSize: s.dataMemSize,
}
t.data = s.data.Clone(t.views[:])
return t
@@ -204,7 +206,7 @@ func (s *segment) payloadSize() int {
// segMemSize is the amount of memory used to hold the segment data and
// the associated metadata.
func (s *segment) segMemSize() int {
- return SegSize + s.data.Size()
+ return SegSize + s.dataMemSize
}
// parse populates the sequence & ack numbers, flags, and window fields of the
diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go
index 7dc2741a6..7422d8c02 100644
--- a/pkg/tcpip/transport/tcp/segment_state.go
+++ b/pkg/tcpip/transport/tcp/segment_state.go
@@ -24,16 +24,11 @@ import (
func (s *segment) saveData() buffer.VectorisedView {
// We cannot save s.data directly as s.data.views may alias to s.views,
// which is not allowed by state framework (in-struct pointer).
- v := make([]buffer.View, len(s.data.Views()))
- // For views already delivered, we cannot save them directly as they may
- // have already been sliced and saved elsewhere (e.g., readViews).
- for i := 0; i < s.viewToDeliver; i++ {
- v[i] = append([]byte(nil), s.data.Views()[i]...)
+ vs := make([]buffer.View, len(s.data.Views()))
+ for i, v := range s.data.Views() {
+ vs[i] = v
}
- for i := s.viewToDeliver; i < len(v); i++ {
- v[i] = s.data.Views()[i]
- }
- return buffer.NewVectorisedView(s.data.Size(), v)
+ return buffer.NewVectorisedView(s.data.Size(), vs)
}
// loadData is invoked by stateify.
diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
index 5922083a9..aab92b94f 100644
--- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go
+++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go
@@ -595,7 +595,6 @@ func (s *segment) StateFields() []string {
"remoteLinkAddr",
"data",
"hdr",
- "viewToDeliver",
"sequenceNumber",
"ackNumber",
"flags",
@@ -609,6 +608,7 @@ func (s *segment) StateFields() []string {
"xmitTime",
"xmitCount",
"acked",
+ "dataMemSize",
}
}
@@ -619,11 +619,11 @@ func (s *segment) StateSave(stateSinkObject state.Sink) {
var dataValue buffer.VectorisedView = s.saveData()
stateSinkObject.SaveValue(9, dataValue)
var optionsValue []byte = s.saveOptions()
- stateSinkObject.SaveValue(19, optionsValue)
+ stateSinkObject.SaveValue(18, optionsValue)
var rcvdTimeValue unixTime = s.saveRcvdTime()
- stateSinkObject.SaveValue(21, rcvdTimeValue)
+ stateSinkObject.SaveValue(20, rcvdTimeValue)
var xmitTimeValue unixTime = s.saveXmitTime()
- stateSinkObject.SaveValue(22, xmitTimeValue)
+ stateSinkObject.SaveValue(21, xmitTimeValue)
stateSinkObject.Save(0, &s.segmentEntry)
stateSinkObject.Save(1, &s.refCnt)
stateSinkObject.Save(2, &s.ep)
@@ -634,17 +634,17 @@ func (s *segment) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(7, &s.nicID)
stateSinkObject.Save(8, &s.remoteLinkAddr)
stateSinkObject.Save(10, &s.hdr)
- stateSinkObject.Save(11, &s.viewToDeliver)
- stateSinkObject.Save(12, &s.sequenceNumber)
- stateSinkObject.Save(13, &s.ackNumber)
- stateSinkObject.Save(14, &s.flags)
- stateSinkObject.Save(15, &s.window)
- stateSinkObject.Save(16, &s.csum)
- stateSinkObject.Save(17, &s.csumValid)
- stateSinkObject.Save(18, &s.parsedOptions)
- stateSinkObject.Save(20, &s.hasNewSACKInfo)
- stateSinkObject.Save(23, &s.xmitCount)
- stateSinkObject.Save(24, &s.acked)
+ stateSinkObject.Save(11, &s.sequenceNumber)
+ stateSinkObject.Save(12, &s.ackNumber)
+ stateSinkObject.Save(13, &s.flags)
+ stateSinkObject.Save(14, &s.window)
+ stateSinkObject.Save(15, &s.csum)
+ stateSinkObject.Save(16, &s.csumValid)
+ stateSinkObject.Save(17, &s.parsedOptions)
+ stateSinkObject.Save(19, &s.hasNewSACKInfo)
+ stateSinkObject.Save(22, &s.xmitCount)
+ stateSinkObject.Save(23, &s.acked)
+ stateSinkObject.Save(24, &s.dataMemSize)
}
func (s *segment) afterLoad() {}
@@ -660,21 +660,21 @@ func (s *segment) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(7, &s.nicID)
stateSourceObject.Load(8, &s.remoteLinkAddr)
stateSourceObject.Load(10, &s.hdr)
- stateSourceObject.Load(11, &s.viewToDeliver)
- stateSourceObject.Load(12, &s.sequenceNumber)
- stateSourceObject.Load(13, &s.ackNumber)
- stateSourceObject.Load(14, &s.flags)
- stateSourceObject.Load(15, &s.window)
- stateSourceObject.Load(16, &s.csum)
- stateSourceObject.Load(17, &s.csumValid)
- stateSourceObject.Load(18, &s.parsedOptions)
- stateSourceObject.Load(20, &s.hasNewSACKInfo)
- stateSourceObject.Load(23, &s.xmitCount)
- stateSourceObject.Load(24, &s.acked)
+ stateSourceObject.Load(11, &s.sequenceNumber)
+ stateSourceObject.Load(12, &s.ackNumber)
+ stateSourceObject.Load(13, &s.flags)
+ stateSourceObject.Load(14, &s.window)
+ stateSourceObject.Load(15, &s.csum)
+ stateSourceObject.Load(16, &s.csumValid)
+ stateSourceObject.Load(17, &s.parsedOptions)
+ stateSourceObject.Load(19, &s.hasNewSACKInfo)
+ stateSourceObject.Load(22, &s.xmitCount)
+ stateSourceObject.Load(23, &s.acked)
+ stateSourceObject.Load(24, &s.dataMemSize)
stateSourceObject.LoadValue(9, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) })
- stateSourceObject.LoadValue(19, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
- stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) })
- stateSourceObject.LoadValue(22, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(18, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) })
+ stateSourceObject.LoadValue(20, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) })
+ stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) })
}
func (q *segmentQueue) StateTypeName() string {
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 4e8bd8b04..075de1db0 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -16,6 +16,7 @@ package udp
import (
"fmt"
+ "io"
"sync/atomic"
"gvisor.dev/gvisor/pkg/sync"
@@ -282,11 +283,10 @@ func (e *endpoint) Close() {
// ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf.
func (e *endpoint) ModerateRecvBuf(copied int) {}
-// Read reads data from the endpoint. This method does not block if
-// there is no data pending.
-func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
+// Read implements tcpip.Endpoint.Read.
+func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
e.rcvMu.Lock()
@@ -298,18 +298,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
err = tcpip.ErrClosedForReceive
}
e.rcvMu.Unlock()
- return buffer.View{}, tcpip.ControlMessages{}, err
+ return tcpip.ReadResult{}, err
}
p := e.rcvList.Front()
- e.rcvList.Remove(p)
- e.rcvBufSize -= p.data.Size()
- e.rcvMu.Unlock()
-
- if addr != nil {
- *addr = p.senderAddress
+ if !opts.Peek {
+ e.rcvList.Remove(p)
+ e.rcvBufSize -= p.data.Size()
}
+ e.rcvMu.Unlock()
+ // Control Messages
cm := tcpip.ControlMessages{
HasTimestamp: true,
Timestamp: p.timestamp,
@@ -331,7 +330,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess
cm.HasOriginalDstAddress = true
cm.OriginalDstAddress = p.destinationAddress
}
- return p.data.ToView(), cm, nil
+
+ // Read Result
+ res := tcpip.ReadResult{
+ Total: p.data.Size(),
+ ControlMessages: cm,
+ }
+ if opts.NeedRemoteAddr {
+ res.RemoteAddr = p.senderAddress
+ }
+
+ n, err := p.data.ReadTo(dst, count, opts.Peek)
+ if n == 0 && err != nil {
+ return res, tcpip.ErrBadBuffer
+ }
+ res.Count = n
+ return res, nil
}
// prepareForWrite prepares the endpoint for sending data. In particular, it
@@ -566,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
return int64(len(v)), nil, nil
}
-// Peek only returns data from a single datagram, so do nothing here.
-func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) {
- return 0, nil
-}
-
// OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet.
func (e *endpoint) OnReuseAddressSet(v bool) {
e.mu.Lock()