diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 57 | ||||
-rw-r--r-- | pkg/tcpip/buffer/view.go | 37 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 81 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 38 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 46 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 34 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 228 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment.go | 16 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/segment_state.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/tcp_state_autogen.go | 58 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 43 |
11 files changed, 377 insertions, 274 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 4f551cd92..7193f56ad 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -286,45 +286,47 @@ type opErrorer interface { // commonRead implements the common logic between net.Conn.Read and // net.PacketConn.ReadFrom. -func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { +func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) (int, error) { select { case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) + return 0, errorer.newOpError("read", &timeoutError{}) default: } - read, _, err := ep.Read(addr) + w := tcpip.SliceWriter(b) + opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} + res, err := ep.Read(&w, len(b), opts) if err == tcpip.ErrWouldBlock { - if dontWait { - return nil, errWouldBlock - } // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { - read, _, err = ep.Read(addr) + res, err = ep.Read(&w, len(b), opts) if err != tcpip.ErrWouldBlock { break } select { case <-deadline: - return nil, errorer.newOpError("read", &timeoutError{}) + return 0, errorer.newOpError("read", &timeoutError{}) case <-notifyCh: } } } if err == tcpip.ErrClosedForReceive { - return nil, io.EOF + return 0, io.EOF } if err != nil { - return nil, errorer.newOpError("read", errors.New(err.String())) + return 0, errorer.newOpError("read", errors.New(err.String())) } - return read, nil + if addr != nil { + *addr = res.RemoteAddr + } + return res.Count, nil } // Read implements net.Conn.Read. @@ -334,31 +336,11 @@ func (c *TCPConn) Read(b []byte) (int, error) { deadline := c.readCancel() - numRead := 0 - defer func() { - if numRead != 0 { - c.ep.ModerateRecvBuf(numRead) - } - }() - for numRead != len(b) { - if len(c.read) == 0 { - var err error - c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) - if err != nil { - if numRead != 0 { - return numRead, nil - } - return numRead, err - } - } - n := copy(b[numRead:], c.read) - c.read.TrimFront(n) - numRead += n - if len(c.read) == 0 { - c.read = nil - } + n, err := commonRead(b, c.ep, c.wq, deadline, nil, c) + if n != 0 { + c.ep.ModerateRecvBuf(n) } - return numRead, nil + return n, err } // Write implements net.Conn.Write. @@ -652,12 +634,11 @@ func (c *UDPConn) ReadFrom(b []byte) (int, net.Addr, error) { deadline := c.readCancel() var addr tcpip.FullAddress - read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) + n, err := commonRead(b, c.ep, c.wq, deadline, &addr, c) if err != nil { return 0, nil, err } - - return copy(b, read), fullToUDPAddr(addr), nil + return n, fullToUDPAddr(addr), nil } func (c *UDPConn) Write(b []byte) (int, error) { diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 8db70a700..5dd1b1b6b 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -105,18 +105,18 @@ func (vv *VectorisedView) TrimFront(count int) { } // Read implements io.Reader. -func (vv *VectorisedView) Read(v View) (copied int, err error) { - count := len(v) +func (vv *VectorisedView) Read(b []byte) (copied int, err error) { + count := len(b) for count > 0 && len(vv.views) > 0 { if count < len(vv.views[0]) { vv.size -= count - copy(v[copied:], vv.views[0][:count]) + copy(b[copied:], vv.views[0][:count]) vv.views[0].TrimFront(count) copied += count return copied, nil } count -= len(vv.views[0]) - copy(v[copied:], vv.views[0]) + copy(b[copied:], vv.views[0]) copied += len(vv.views[0]) vv.removeFirst() } @@ -145,6 +145,35 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int return copied } +// ReadTo reads up to count bytes from vv to dst. It also removes them from vv +// unless peek is true. +func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) { + var err error + done := 0 + for _, v := range vv.Views() { + remaining := count - done + if remaining <= 0 { + break + } + if len(v) > remaining { + v = v[:remaining] + } + + var n int + n, err = dst.Write(v) + if n > 0 { + done += n + } + if err != nil { + break + } + } + if !peek { + vv.TrimFront(done) + } + return done, err +} + // CapLength irreversibly reduces the length of the vectorised view. func (vv *VectorisedView) CapLength(length int) { if length < 0 { diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index ef0f51f1a..f798056c0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -31,6 +31,7 @@ package tcpip import ( "errors" "fmt" + "io" "math/bits" "reflect" "strconv" @@ -39,7 +40,6 @@ import ( "time" "gvisor.dev/gvisor/pkg/sync" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/waiter" ) @@ -113,6 +113,7 @@ var ( ErrNotPermitted = &Error{msg: "operation not permitted"} ErrAddressFamilyNotSupported = &Error{msg: "address family not supported by protocol"} ErrMalformedHeader = &Error{msg: "header is malformed"} + ErrBadBuffer = &Error{msg: "bad buffer"} ) var messageToError map[string]*Error @@ -162,6 +163,7 @@ func StringToError(s string) *Error { ErrNotPermitted, ErrAddressFamilyNotSupported, ErrMalformedHeader, + ErrBadBuffer, } messageToError = make(map[string]*Error) @@ -496,6 +498,21 @@ func (s SlicePayload) Payload(size int) ([]byte, *Error) { return s[:size], nil } +var _ io.Writer = (*SliceWriter)(nil) + +// SliceWriter implements io.Writer for slices. +type SliceWriter []byte + +// Write implements io.Writer.Write. +func (s *SliceWriter) Write(b []byte) (int, error) { + n := copy(*s, b) + *s = (*s)[n:] + if n < len(b) { + return n, io.ErrShortWrite + } + return n, nil +} + // A ControlMessages contains socket control messages for IP sockets. // // +stateify savable @@ -552,6 +569,40 @@ type PacketOwner interface { GID() uint32 } +// ReadOptions contains options for Endpoint.Read. +type ReadOptions struct { + // Peek indicates whether this read is a peek. + Peek bool + + // NeedRemoteAddr indicates whether to return the remote address, if + // supported. + NeedRemoteAddr bool + + // NeedLinkPacketInfo indicates whether to return the link-layer information, + // if supported. + NeedLinkPacketInfo bool +} + +// ReadResult represents result for a successful Endpoint.Read. +type ReadResult struct { + // Count is the number of bytes received and written to the buffer. + Count int + + // Total is the number of bytes of the received packet. This can be used to + // determine whether the read is truncated. + Total int + + // ControlMessages is the control messages received. + ControlMessages ControlMessages + + // RemoteAddr is the remote address if ReadOptions.NeedAddr is true. + RemoteAddr FullAddress + + // LinkPacketInfo is the link-layer information of the received packet if + // ReadOptions.NeedLinkPacketInfo is true. + LinkPacketInfo LinkPacketInfo +} + // Endpoint is the interface implemented by transport protocols (e.g., tcp, udp) // that exposes functionality like read, write, connect, etc. to users of the // networking stack. @@ -566,11 +617,15 @@ type Endpoint interface { // Abort is best effort; implementing Abort with Close is acceptable. Abort() - // Read reads data from the endpoint and optionally returns the sender. + // Read reads data from the endpoint and optionally writes to dst. + // + // This method does not block if there is no data pending; in this case, + // ErrWouldBlock is returned. // - // This method does not block if there is no data pending. It will also - // either return an error or data, never both. - Read(*FullAddress) (buffer.View, ControlMessages, *Error) + // If non-zero number of bytes are successfully read and written to dst, err + // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer + // should be returned. + Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. @@ -592,11 +647,6 @@ type Endpoint interface { // not). The channel is only non-nil in this case. Write(Payloader, WriteOptions) (int64, <-chan struct{}, *Error) - // Peek reads data without consuming it from the endpoint. - // - // This method does not block if there is no data pending. - Peek([][]byte) (int64, *Error) - // Connect connects the endpoint to its peer. Specifying a NIC is // optional. // @@ -703,17 +753,6 @@ type LinkPacketInfo struct { PktType PacketType } -// PacketEndpoint are additional methods that are only implemented by Packet -// endpoints. -type PacketEndpoint interface { - // ReadPacket reads a datagram/packet from the endpoint and optionally - // returns the sender and additional LinkPacketInfo. - // - // This method does not block if there is no data pending. It will also - // either return an error or data, never both. - ReadPacket(*FullAddress, *LinkPacketInfo) (buffer.View, ControlMessages, *Error) -} - // EndpointInfo is the interface implemented by each endpoint info struct. type EndpointInfo interface { // IsEndpointInfo is an empty method to implement the tcpip.EndpointInfo diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index d1e4a7cb7..2eb4457df 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -15,6 +15,8 @@ package icmp import ( + "io" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -151,9 +153,8 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -// Read reads data from the endpoint. This method does not block if -// there is no data pending. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { @@ -163,20 +164,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } p := e.rcvList.Front() - e.rcvList.Remove(p) - e.rcvBufSize -= p.data.Size() + if !opts.Peek { + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() + } e.rcvMu.Unlock() - if addr != nil { - *addr = p.senderAddress + res := tcpip.ReadResult{ + Total: p.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: p.timestamp, + }, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = p.senderAddress } - return p.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: p.timestamp}, nil + n, err := p.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -329,11 +344,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return int64(len(v)), nil, nil } -// Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // SetSockOpt sets a socket option. func (e *endpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { return nil diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index e5e247342..3ab060751 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -26,6 +26,7 @@ package packet import ( "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -160,8 +161,8 @@ func (ep *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (ep *endpoint) ModerateRecvBuf(copied int) {} -// Read implements tcpip.PacketEndpoint.ReadPacket. -func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketInfo) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -173,29 +174,37 @@ func (ep *endpoint) ReadPacket(addr *tcpip.FullAddress, info *tcpip.LinkPacketIn err = tcpip.ErrClosedForReceive } ep.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } packet := ep.rcvList.Front() - ep.rcvList.Remove(packet) - ep.rcvBufSize -= packet.data.Size() + if !opts.Peek { + ep.rcvList.Remove(packet) + ep.rcvBufSize -= packet.data.Size() + } ep.rcvMu.Unlock() - if addr != nil { - *addr = packet.senderAddr + res := tcpip.ReadResult{ + Total: packet.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: packet.timestampNS, + }, } - - if info != nil { - *info = packet.packetInfo + if opts.NeedRemoteAddr { + res.RemoteAddr = packet.senderAddr + } + if opts.NeedLinkPacketInfo { + res.LinkPacketInfo = packet.packetInfo } - return packet.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: packet.timestampNS}, nil -} - -// Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - return ep.ReadPacket(addr, nil) + n, err := packet.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { @@ -203,11 +212,6 @@ func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-cha return 0, nil, tcpip.ErrInvalidOptionValue } -// Peek implements tcpip.Endpoint.Peek. -func (*endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. Packet sockets cannot be // disconnected, and this function always returns tpcip.ErrNotSupported. func (*endpoint) Disconnect() *tcpip.Error { diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 7befcfc9b..dd260535f 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -27,6 +27,7 @@ package raw import ( "fmt" + "io" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" @@ -190,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -202,20 +203,34 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } pkt := e.rcvList.Front() - e.rcvList.Remove(pkt) - e.rcvBufSize -= pkt.data.Size() + if !opts.Peek { + e.rcvList.Remove(pkt) + e.rcvBufSize -= pkt.data.Size() + } e.rcvMu.Unlock() - if addr != nil { - *addr = pkt.senderAddr + res := tcpip.ReadResult{ + Total: pkt.data.Size(), + ControlMessages: tcpip.ControlMessages{ + HasTimestamp: true, + Timestamp: pkt.timestampNS, + }, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = pkt.senderAddr } - return pkt.data.ToView(), tcpip.ControlMessages{HasTimestamp: true, Timestamp: pkt.timestampNS}, nil + n, err := pkt.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // Write implements tcpip.Endpoint.Write. @@ -363,11 +378,6 @@ func (e *endpoint) finishWrite(payloadBytes []byte, route *stack.Route) (int64, return int64(len(payloadBytes)), nil, nil } -// Peek implements tcpip.Endpoint.Peek. -func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // Disconnect implements tcpip.Endpoint.Disconnect. func (*endpoint) Disconnect() *tcpip.Error { return tcpip.ErrNotSupported diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 6e3c8860e..8f3981075 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -17,6 +17,7 @@ package tcp import ( "encoding/binary" "fmt" + "io" "math" "runtime" "strings" @@ -27,7 +28,6 @@ import ( "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/hash/jenkins" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/ports" @@ -393,15 +393,28 @@ type endpoint struct { lastErrorMu sync.Mutex `state:"nosave"` lastError *tcpip.Error `state:".(string)"` - // The following fields are used to manage the receive queue. The - // protocol goroutine adds ready-for-delivery segments to rcvList, - // which are returned by Read() calls to users. + // rcvReadMu synchronizes calls to Read. // - // Once the peer has closed its send side, rcvClosed is set to true - // to indicate to users that no more data is coming. + // mu and rcvListMu are temporarily released during data copying. rcvReadMu + // must be held during each read to ensure atomicity, so that multiple reads + // do not interleave. + // + // rcvReadMu should be held before holding mu. + rcvReadMu sync.Mutex `state:"nosave"` + + // rcvListMu synchronizes access to rcvList. // // rcvListMu can be taken after the endpoint mu below. - rcvListMu sync.Mutex `state:"nosave"` + rcvListMu sync.Mutex `state:"nosave"` + + // rcvList is the queue for ready-for-delivery segments. + // + // rcvReadMu, mu and rcvListMu must be held, in the stated order, to read data + // and removing segments from list. A range of segment can be determined, then + // temporarily release mu and rcvListMu while processing the segment range. + // This allows new segments to be appended to the list while processing. + // + // rcvListMu must be held to append segments to list. rcvList segmentList `state:"wait"` rcvClosed bool // rcvBufSize is the total size of the receive buffer. @@ -1309,8 +1322,69 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { e.UnlockUser() } -// Read reads data from the endpoint. -func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { + e.rcvReadMu.Lock() + defer e.rcvReadMu.Unlock() + + // N.B. Here we get a range of segments to be processed. It is safe to not + // hold rcvListMu when processing, since we hold rcvReadMu to ensure only we + // can remove segments from the list through commitRead(). + first, last, serr := e.startRead() + if serr != nil { + if serr == tcpip.ErrClosedForReceive { + e.stats.ReadErrors.ReadClosed.Increment() + } + return tcpip.ReadResult{}, serr + } + + var err error + done := 0 + s := first + for s != nil && done < count { + var n int + n, err = s.data.ReadTo(dst, count-done, opts.Peek) + // Book keeping first then error handling. + + done += n + + if opts.Peek { + // For peek, we use the (first, last) range of segment returned from + // startRead. We don't consume the receive buffer, so commitRead should + // not be called. + // + // N.B. It is important to use `last` to determine the last segment, since + // appending can happen while we process, and will lead to data race. + if s == last { + break + } + s = s.Next() + } else { + // N.B. commitRead() conveniently returns the next segment to read, after + // removing the data/segment that is read. + s = e.commitRead(n) + } + + if err != nil { + break + } + } + + // If something is read, we must report it. Report error when nothing is read. + if done == 0 && err != nil { + return tcpip.ReadResult{}, tcpip.ErrBadBuffer + } + return tcpip.ReadResult{ + Count: done, + Total: done, + }, nil +} + +// startRead checks that endpoint is in a readable state, and return the +// inclusive range of segments that can be read. +// +// Precondition: e.rcvReadMu must be held. +func (e *endpoint) startRead() (first, last *segment, err *tcpip.Error) { e.LockUser() defer e.UnlockUser() @@ -1319,7 +1393,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // on a receive. It can expect to read any data after the handshake // is complete. RFC793, section 3.9, p58. if e.EndpointState() == StateSynSent { - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrWouldBlock + return nil, nil, tcpip.ErrWouldBlock } // The endpoint can be read if it's connected, or if it's already closed @@ -1327,61 +1401,69 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() + bufUsed := e.rcvBufUsed if s := e.EndpointState(); !s.connected() && s != StateClose && bufUsed == 0 { - e.rcvListMu.Unlock() if s == StateError { if err := e.hardErrorLocked(); err != nil { - return buffer.View{}, tcpip.ControlMessages{}, err + return nil, nil, err } - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive + return nil, nil, tcpip.ErrClosedForReceive } e.stats.ReadErrors.NotConnected.Increment() - return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrNotConnected + return nil, nil, tcpip.ErrNotConnected } - v, err := e.readLocked() - e.rcvListMu.Unlock() - - if err == tcpip.ErrClosedForReceive { - e.stats.ReadErrors.ReadClosed.Increment() - } - return v, tcpip.ControlMessages{}, err -} - -func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { if e.rcvClosed || !e.EndpointState().connected() { - return buffer.View{}, tcpip.ErrClosedForReceive + return nil, nil, tcpip.ErrClosedForReceive } - return buffer.View{}, tcpip.ErrWouldBlock + return nil, nil, tcpip.ErrWouldBlock } - s := e.rcvList.Front() - views := s.data.Views() - v := views[s.viewToDeliver] - s.viewToDeliver++ + return e.rcvList.Front(), e.rcvList.Back(), nil +} + +// commitRead commits a read of done bytes and returns the next non-empty +// segment to read. Data read from the segment must have also been removed from +// the segment in order for this method to work correctly. +// +// It is performance critical to call commitRead frequently when servicing a big +// Read request, so TCP can make progress timely. Right now, it is designed to +// do this per segment read, hence this method conveniently returns the next +// segment to read while holding the lock. +// +// Precondition: e.rcvReadMu must be held. +func (e *endpoint) commitRead(done int) *segment { + e.LockUser() + defer e.UnlockUser() + e.rcvListMu.Lock() + defer e.rcvListMu.Unlock() - var delta int - if s.viewToDeliver >= len(views) { + memDelta := 0 + s := e.rcvList.Front() + for s != nil && s.data.Size() == 0 { e.rcvList.Remove(s) - // We only free up receive buffer space when the segment is released as the - // segment is still holding on to the views even though some views have been - // read out to the user. - delta = s.segMemSize() + // Memory is only considered released when the whole segment has been + // read. + memDelta += s.segMemSize() s.decRef() + s = e.rcvList.Front() } + e.rcvBufUsed -= done - e.rcvBufUsed -= len(v) - // If the window was small before this read and if the read freed up - // enough buffer space, to either fit an aMSS or half a receive buffer - // (whichever smaller), then notify the protocol goroutine to send a - // window update. - if crossed, above := e.windowCrossedACKThresholdLocked(delta); crossed && above { - e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + if memDelta > 0 { + // If the window was small before this read and if the read freed up + // enough buffer space, to either fit an aMSS or half a receive buffer + // (whichever smaller), then notify the protocol goroutine to send a + // window update. + if crossed, above := e.windowCrossedACKThresholdLocked(memDelta); crossed && above { + e.notifyProtocolGoroutine(notifyNonZeroReceiveWindow) + } } - return v, nil + return e.rcvList.Front() } // isEndpointWritableLocked checks if a given endpoint is writable @@ -1499,64 +1581,6 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return queueAndSend() } -// Peek reads data without consuming it from the endpoint. -// -// This method does not block if there is no data pending. -func (e *endpoint) Peek(vec [][]byte) (int64, *tcpip.Error) { - e.LockUser() - defer e.UnlockUser() - - // The endpoint can be read if it's connected, or if it's already closed - // but has some pending unread data. - if s := e.EndpointState(); !s.connected() && s != StateClose { - if s == StateError { - return 0, e.hardErrorLocked() - } - e.stats.ReadErrors.InvalidEndpointState.Increment() - return 0, tcpip.ErrInvalidEndpointState - } - - e.rcvListMu.Lock() - defer e.rcvListMu.Unlock() - - if e.rcvBufUsed == 0 { - if e.rcvClosed || !e.EndpointState().connected() { - e.stats.ReadErrors.ReadClosed.Increment() - return 0, tcpip.ErrClosedForReceive - } - return 0, tcpip.ErrWouldBlock - } - - // Make a copy of vec so we can modify the slide headers. - vec = append([][]byte(nil), vec...) - - var num int64 - for s := e.rcvList.Front(); s != nil; s = s.Next() { - views := s.data.Views() - - for i := s.viewToDeliver; i < len(views); i++ { - v := views[i] - - for len(v) > 0 { - if len(vec) == 0 { - return num, nil - } - if len(vec[0]) == 0 { - vec = vec[1:] - continue - } - - n := copy(vec[0], v) - v = v[n:] - vec[0] = vec[0][n:] - num += int64(n) - } - } - } - - return num, nil -} - // selectWindowLocked returns the new window without checking for shrinking or scaling // applied. // Precondition: e.mu and e.rcvListMu must be held. diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 5ef73ec74..c5a6d2fba 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -37,7 +37,7 @@ const ( // segment represents a TCP segment. It holds the payload and parsed TCP segment // information, and can be added to intrusive lists. -// segment is mostly immutable, the only field allowed to change is viewToDeliver. +// segment is mostly immutable, the only field allowed to change is data. // // +stateify savable type segment struct { @@ -60,10 +60,7 @@ type segment struct { hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. - views [8]buffer.View `state:"nosave"` - // viewToDeliver keeps track of the next View that should be - // delivered by the Read endpoint. - viewToDeliver int + views [8]buffer.View `state:"nosave"` sequenceNumber seqnum.Value ackNumber seqnum.Value flags uint8 @@ -84,6 +81,9 @@ type segment struct { // acked indicates if the segment has already been SACKed. acked bool + + // dataMemSize is the memory used by data initially. + dataMemSize int } func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { @@ -100,6 +100,7 @@ func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) * s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) s.rcvdTime = time.Now() + s.dataMemSize = s.data.Size() return s } @@ -113,6 +114,7 @@ func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s.views[0] = v s.data = buffer.NewVectorisedView(len(v), s.views[:1]) } + s.dataMemSize = s.data.Size() return s } @@ -127,12 +129,12 @@ func (s *segment) clone() *segment { netProto: s.netProto, nicID: s.nicID, remoteLinkAddr: s.remoteLinkAddr, - viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, xmitCount: s.xmitCount, ep: s.ep, qFlags: s.qFlags, + dataMemSize: s.dataMemSize, } t.data = s.data.Clone(t.views[:]) return t @@ -204,7 +206,7 @@ func (s *segment) payloadSize() int { // segMemSize is the amount of memory used to hold the segment data and // the associated metadata. func (s *segment) segMemSize() int { - return SegSize + s.data.Size() + return SegSize + s.dataMemSize } // parse populates the sequence & ack numbers, flags, and window fields of the diff --git a/pkg/tcpip/transport/tcp/segment_state.go b/pkg/tcpip/transport/tcp/segment_state.go index 7dc2741a6..7422d8c02 100644 --- a/pkg/tcpip/transport/tcp/segment_state.go +++ b/pkg/tcpip/transport/tcp/segment_state.go @@ -24,16 +24,11 @@ import ( func (s *segment) saveData() buffer.VectorisedView { // We cannot save s.data directly as s.data.views may alias to s.views, // which is not allowed by state framework (in-struct pointer). - v := make([]buffer.View, len(s.data.Views())) - // For views already delivered, we cannot save them directly as they may - // have already been sliced and saved elsewhere (e.g., readViews). - for i := 0; i < s.viewToDeliver; i++ { - v[i] = append([]byte(nil), s.data.Views()[i]...) + vs := make([]buffer.View, len(s.data.Views())) + for i, v := range s.data.Views() { + vs[i] = v } - for i := s.viewToDeliver; i < len(v); i++ { - v[i] = s.data.Views()[i] - } - return buffer.NewVectorisedView(s.data.Size(), v) + return buffer.NewVectorisedView(s.data.Size(), vs) } // loadData is invoked by stateify. diff --git a/pkg/tcpip/transport/tcp/tcp_state_autogen.go b/pkg/tcpip/transport/tcp/tcp_state_autogen.go index 5922083a9..aab92b94f 100644 --- a/pkg/tcpip/transport/tcp/tcp_state_autogen.go +++ b/pkg/tcpip/transport/tcp/tcp_state_autogen.go @@ -595,7 +595,6 @@ func (s *segment) StateFields() []string { "remoteLinkAddr", "data", "hdr", - "viewToDeliver", "sequenceNumber", "ackNumber", "flags", @@ -609,6 +608,7 @@ func (s *segment) StateFields() []string { "xmitTime", "xmitCount", "acked", + "dataMemSize", } } @@ -619,11 +619,11 @@ func (s *segment) StateSave(stateSinkObject state.Sink) { var dataValue buffer.VectorisedView = s.saveData() stateSinkObject.SaveValue(9, dataValue) var optionsValue []byte = s.saveOptions() - stateSinkObject.SaveValue(19, optionsValue) + stateSinkObject.SaveValue(18, optionsValue) var rcvdTimeValue unixTime = s.saveRcvdTime() - stateSinkObject.SaveValue(21, rcvdTimeValue) + stateSinkObject.SaveValue(20, rcvdTimeValue) var xmitTimeValue unixTime = s.saveXmitTime() - stateSinkObject.SaveValue(22, xmitTimeValue) + stateSinkObject.SaveValue(21, xmitTimeValue) stateSinkObject.Save(0, &s.segmentEntry) stateSinkObject.Save(1, &s.refCnt) stateSinkObject.Save(2, &s.ep) @@ -634,17 +634,17 @@ func (s *segment) StateSave(stateSinkObject state.Sink) { stateSinkObject.Save(7, &s.nicID) stateSinkObject.Save(8, &s.remoteLinkAddr) stateSinkObject.Save(10, &s.hdr) - stateSinkObject.Save(11, &s.viewToDeliver) - stateSinkObject.Save(12, &s.sequenceNumber) - stateSinkObject.Save(13, &s.ackNumber) - stateSinkObject.Save(14, &s.flags) - stateSinkObject.Save(15, &s.window) - stateSinkObject.Save(16, &s.csum) - stateSinkObject.Save(17, &s.csumValid) - stateSinkObject.Save(18, &s.parsedOptions) - stateSinkObject.Save(20, &s.hasNewSACKInfo) - stateSinkObject.Save(23, &s.xmitCount) - stateSinkObject.Save(24, &s.acked) + stateSinkObject.Save(11, &s.sequenceNumber) + stateSinkObject.Save(12, &s.ackNumber) + stateSinkObject.Save(13, &s.flags) + stateSinkObject.Save(14, &s.window) + stateSinkObject.Save(15, &s.csum) + stateSinkObject.Save(16, &s.csumValid) + stateSinkObject.Save(17, &s.parsedOptions) + stateSinkObject.Save(19, &s.hasNewSACKInfo) + stateSinkObject.Save(22, &s.xmitCount) + stateSinkObject.Save(23, &s.acked) + stateSinkObject.Save(24, &s.dataMemSize) } func (s *segment) afterLoad() {} @@ -660,21 +660,21 @@ func (s *segment) StateLoad(stateSourceObject state.Source) { stateSourceObject.Load(7, &s.nicID) stateSourceObject.Load(8, &s.remoteLinkAddr) stateSourceObject.Load(10, &s.hdr) - stateSourceObject.Load(11, &s.viewToDeliver) - stateSourceObject.Load(12, &s.sequenceNumber) - stateSourceObject.Load(13, &s.ackNumber) - stateSourceObject.Load(14, &s.flags) - stateSourceObject.Load(15, &s.window) - stateSourceObject.Load(16, &s.csum) - stateSourceObject.Load(17, &s.csumValid) - stateSourceObject.Load(18, &s.parsedOptions) - stateSourceObject.Load(20, &s.hasNewSACKInfo) - stateSourceObject.Load(23, &s.xmitCount) - stateSourceObject.Load(24, &s.acked) + stateSourceObject.Load(11, &s.sequenceNumber) + stateSourceObject.Load(12, &s.ackNumber) + stateSourceObject.Load(13, &s.flags) + stateSourceObject.Load(14, &s.window) + stateSourceObject.Load(15, &s.csum) + stateSourceObject.Load(16, &s.csumValid) + stateSourceObject.Load(17, &s.parsedOptions) + stateSourceObject.Load(19, &s.hasNewSACKInfo) + stateSourceObject.Load(22, &s.xmitCount) + stateSourceObject.Load(23, &s.acked) + stateSourceObject.Load(24, &s.dataMemSize) stateSourceObject.LoadValue(9, new(buffer.VectorisedView), func(y interface{}) { s.loadData(y.(buffer.VectorisedView)) }) - stateSourceObject.LoadValue(19, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) }) - stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) }) - stateSourceObject.LoadValue(22, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) }) + stateSourceObject.LoadValue(18, new([]byte), func(y interface{}) { s.loadOptions(y.([]byte)) }) + stateSourceObject.LoadValue(20, new(unixTime), func(y interface{}) { s.loadRcvdTime(y.(unixTime)) }) + stateSourceObject.LoadValue(21, new(unixTime), func(y interface{}) { s.loadXmitTime(y.(unixTime)) }) } func (q *segmentQueue) StateTypeName() string { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 4e8bd8b04..075de1db0 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -16,6 +16,7 @@ package udp import ( "fmt" + "io" "sync/atomic" "gvisor.dev/gvisor/pkg/sync" @@ -282,11 +283,10 @@ func (e *endpoint) Close() { // ModerateRecvBuf implements tcpip.Endpoint.ModerateRecvBuf. func (e *endpoint) ModerateRecvBuf(copied int) {} -// Read reads data from the endpoint. This method does not block if -// there is no data pending. -func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { +// Read implements tcpip.Endpoint.Read. +func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { if err := e.LastError(); err != nil { - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } e.rcvMu.Lock() @@ -298,18 +298,17 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess err = tcpip.ErrClosedForReceive } e.rcvMu.Unlock() - return buffer.View{}, tcpip.ControlMessages{}, err + return tcpip.ReadResult{}, err } p := e.rcvList.Front() - e.rcvList.Remove(p) - e.rcvBufSize -= p.data.Size() - e.rcvMu.Unlock() - - if addr != nil { - *addr = p.senderAddress + if !opts.Peek { + e.rcvList.Remove(p) + e.rcvBufSize -= p.data.Size() } + e.rcvMu.Unlock() + // Control Messages cm := tcpip.ControlMessages{ HasTimestamp: true, Timestamp: p.timestamp, @@ -331,7 +330,22 @@ func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMess cm.HasOriginalDstAddress = true cm.OriginalDstAddress = p.destinationAddress } - return p.data.ToView(), cm, nil + + // Read Result + res := tcpip.ReadResult{ + Total: p.data.Size(), + ControlMessages: cm, + } + if opts.NeedRemoteAddr { + res.RemoteAddr = p.senderAddress + } + + n, err := p.data.ReadTo(dst, count, opts.Peek) + if n == 0 && err != nil { + return res, tcpip.ErrBadBuffer + } + res.Count = n + return res, nil } // prepareForWrite prepares the endpoint for sending data. In particular, it @@ -566,11 +580,6 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c return int64(len(v)), nil, nil } -// Peek only returns data from a single datagram, so do nothing here. -func (e *endpoint) Peek([][]byte) (int64, *tcpip.Error) { - return 0, nil -} - // OnReuseAddressSet implements tcpip.SocketOptionsHandler.OnReuseAddressSet. func (e *endpoint) OnReuseAddressSet(v bool) { e.mu.Lock() |