diff options
Diffstat (limited to 'pkg/tcpip/adapters/gonet/gonet.go')
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 62 |
1 files changed, 33 insertions, 29 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 81428770b..8b077156c 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -30,7 +30,10 @@ import ( "gvisor.googlesource.com/gvisor/pkg/waiter" ) -var errCanceled = errors.New("operation canceled") +var ( + errCanceled = errors.New("operation canceled") + errWouldBlock = errors.New("operation would block") +) // timeoutError is how the net package reports timeouts. type timeoutError struct{} @@ -277,10 +280,19 @@ type opErrorer interface { // commonRead implements the common logic between net.Conn.Read and // net.PacketConn.ReadFrom. -func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) ([]byte, error) { +func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) { + select { + case <-deadline: + return nil, errorer.newOpError("read", &timeoutError{}) + default: + } + read, _, err := ep.Read(addr) if err == tcpip.ErrWouldBlock { + if dontWait { + return nil, errWouldBlock + } // Create wait queue entry that notifies a channel. waitEntry, notifyCh := waiter.NewChannelEntry(nil) wq.EventRegister(&waitEntry, waiter.EventIn) @@ -316,27 +328,26 @@ func (c *Conn) Read(b []byte) (int, error) { deadline := c.readCancel() - // Check if deadline has already expired. - select { - case <-deadline: - return 0, c.newOpError("read", &timeoutError{}) - default: - } - - if len(c.read) == 0 { - var err error - c.read, err = commonRead(c.ep, c.wq, deadline, nil, c) - if err != nil { - return 0, err + numRead := 0 + for numRead != len(b) { + if len(c.read) == 0 { + var err error + c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0) + if err != nil { + if numRead != 0 { + return numRead, nil + } + return numRead, err + } + } + n := copy(b[numRead:], c.read) + c.read.TrimFront(n) + numRead += n + if len(c.read) == 0 { + c.read = nil } } - - n := copy(b, c.read) - c.read.TrimFront(n) - if len(c.read) == 0 { - c.read = nil - } - return n, nil + return numRead, nil } // Write implements net.Conn.Write. @@ -550,15 +561,8 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) { deadline := c.readCancel() - // Check if deadline has already expired. - select { - case <-deadline: - return 0, nil, c.newOpError("read", &timeoutError{}) - default: - } - var addr tcpip.FullAddress - read, err := commonRead(c.ep, c.wq, deadline, &addr, c) + read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false) if err != nil { return 0, nil, err } |