diff options
Diffstat (limited to 'pkg/tcpip')
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 67 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 28 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 7 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 2 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 13 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 6 |
7 files changed, 58 insertions, 71 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index fdeec12d3..7c7495c30 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -16,6 +16,7 @@ package gonet import ( + "bytes" "context" "errors" "io" @@ -354,8 +355,6 @@ func (c *TCPConn) Write(b []byte) (int, error) { default: } - v := buffer.NewViewFromBytes(b) - // We must handle two soft failure conditions simultaneously: // 1. Write may write nothing and return tcpip.ErrWouldBlock. // If this happens, we need to register for notifications if we have @@ -368,22 +367,23 @@ func (c *TCPConn) Write(b []byte) (int, error) { // There is no guarantee that all of the condition #1s will occur before // all of the condition #2s or visa-versa. var ( - err *tcpip.Error - nbytes int - reg bool - notifyCh chan struct{} + r bytes.Reader + nbytes int + entry waiter.Entry + ch <-chan struct{} ) - for nbytes < len(b) && (err == tcpip.ErrWouldBlock || err == nil) { - if err == tcpip.ErrWouldBlock { - if !reg { - // Only register once. - reg = true - - // Create wait queue entry that notifies a channel. - var waitEntry waiter.Entry - waitEntry, notifyCh = waiter.NewChannelEntry(nil) - c.wq.EventRegister(&waitEntry, waiter.EventOut) - defer c.wq.EventUnregister(&waitEntry) + for nbytes != len(b) { + r.Reset(b[nbytes:]) + n, err := c.ep.Write(&r, tcpip.WriteOptions{}) + nbytes += int(n) + switch err { + case nil: + case tcpip.ErrWouldBlock: + if ch == nil { + entry, ch = waiter.NewChannelEntry(nil) + + c.wq.EventRegister(&entry, waiter.EventOut) + defer c.wq.EventUnregister(&entry) } else { // Don't wait immediately after registration in case more data // became available between when we last checked and when we setup @@ -391,22 +391,15 @@ func (c *TCPConn) Write(b []byte) (int, error) { select { case <-deadline: return nbytes, c.newOpError("write", &timeoutError{}) - case <-notifyCh: + case <-ch: + continue } } + default: + return nbytes, c.newOpError("write", errors.New(err.String())) } - - var n int64 - n, err = c.ep.Write(tcpip.SlicePayload(v), tcpip.WriteOptions{}) - nbytes += int(n) - v.TrimFront(int(n)) - } - - if err == nil { - return nbytes, nil } - - return nbytes, c.newOpError("write", errors.New(err.String())) + return nbytes, nil } // Close implements net.Conn.Close. @@ -644,16 +637,18 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { } // If we're being called by Write, there is no addr - wopts := tcpip.WriteOptions{} + writeOptions := tcpip.WriteOptions{} if addr != nil { ua := addr.(*net.UDPAddr) - wopts.To = &tcpip.FullAddress{Addr: tcpip.Address(ua.IP), Port: uint16(ua.Port)} + writeOptions.To = &tcpip.FullAddress{ + Addr: tcpip.Address(ua.IP), + Port: uint16(ua.Port), + } } - v := buffer.NewView(len(b)) - copy(v, b) - - n, err := c.ep.Write(tcpip.SlicePayload(v), wopts) + var r bytes.Reader + r.Reset(b) + n, err := c.ep.Write(&r, writeOptions) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -666,7 +661,7 @@ func (c *UDPConn) WriteTo(b []byte, addr net.Addr) (int, error) { case <-notifyCh: } - n, err = c.ep.Write(tcpip.SlicePayload(v), wopts) + n, err = c.ep.Write(&r, writeOptions) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 4f59e4ff7..fe01029ad 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -29,6 +29,7 @@ package tcpip import ( + "bytes" "errors" "fmt" "io" @@ -471,30 +472,15 @@ type FullAddress struct { // This interface allows the endpoint to request the amount of data it needs // based on internal buffers without exposing them. type Payloader interface { - // FullPayload returns all available bytes. - FullPayload() ([]byte, *Error) + io.Reader - // Payload returns a slice containing at most size bytes. - Payload(size int) ([]byte, *Error) + // Len returns the number of bytes of the unread portion of the + // Reader. + Len() int } -// SlicePayload implements Payloader for slices. -// -// This is typically used for tests. -type SlicePayload []byte - -// FullPayload implements Payloader.FullPayload. -func (s SlicePayload) FullPayload() ([]byte, *Error) { - return s, nil -} - -// Payload implements Payloader.Payload. -func (s SlicePayload) Payload(size int) ([]byte, *Error) { - if size > len(s) { - size = len(s) - } - return s[:size], nil -} +var _ Payloader = (*bytes.Buffer)(nil) +var _ Payloader = (*bytes.Reader)(nil) var _ io.Writer = (*SliceWriter)(nil) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 256e19296..af00ed548 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -313,11 +313,12 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc route = r } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } + var err *tcpip.Error switch e.NetProto { case header.IPv4ProtocolNumber: err = send4(route, e.ID.LocalPort, v, e.ttl, e.owner) diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c0d6fb442..6fd116a98 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -207,7 +207,7 @@ func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResul return res, nil } -func (*endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tcpip.Error) { +func (*endpoint) Write(tcpip.Payloader, tcpip.WriteOptions) (int64, *tcpip.Error) { // TODO(gvisor.dev/issue/173): Implement. return 0, tcpip.ErrInvalidOptionValue } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index ae743f75e..2dacf5a64 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -280,9 +280,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrInvalidEndpointState } - payloadBytes, err := p.FullPayload() - if err != nil { - return 0, err + payloadBytes := make([]byte, p.Len()) + if _, err := io.ReadFull(p, payloadBytes); err != nil { + return 0, tcpip.ErrBadBuffer } // If this is an unassociated socket and callee provided a nonzero diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ea509ac73..8d27d43c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1534,14 +1534,19 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc } // Fetch data. - v, perr := p.Payload(avail) - if perr != nil || len(v) == 0 { - // Note that perr may be nil if len(v) == 0. + if l := p.Len(); l < avail { + avail = l + } + if avail == 0 { + return 0, nil + } + v := make([]byte, avail) + if _, err := io.ReadFull(p, v); err != nil { if opts.Atomic { e.sndBufMu.Unlock() e.UnlockUser() } - return 0, perr + return 0, tcpip.ErrBadBuffer } if !opts.Atomic { diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 9f9b3d510..8544fcb08 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -514,9 +514,9 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, *tc return 0, tcpip.ErrBroadcastDisabled } - v, err := p.FullPayload() - if err != nil { - return 0, err + v := make([]byte, p.Len()) + if _, err := io.ReadFull(p, v); err != nil { + return 0, tcpip.ErrBadBuffer } if len(v) > header.UDPMaximumPacketSize { // Payload can't possibly fit in a packet. |