summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go62
1 files changed, 33 insertions, 29 deletions
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 81428770b..8b077156c 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -30,7 +30,10 @@ import (
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
-var errCanceled = errors.New("operation canceled")
+var (
+ errCanceled = errors.New("operation canceled")
+ errWouldBlock = errors.New("operation would block")
+)
// timeoutError is how the net package reports timeouts.
type timeoutError struct{}
@@ -277,10 +280,19 @@ type opErrorer interface {
// commonRead implements the common logic between net.Conn.Read and
// net.PacketConn.ReadFrom.
-func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer) ([]byte, error) {
+func commonRead(ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan struct{}, addr *tcpip.FullAddress, errorer opErrorer, dontWait bool) ([]byte, error) {
+ select {
+ case <-deadline:
+ return nil, errorer.newOpError("read", &timeoutError{})
+ default:
+ }
+
read, _, err := ep.Read(addr)
if err == tcpip.ErrWouldBlock {
+ if dontWait {
+ return nil, errWouldBlock
+ }
// Create wait queue entry that notifies a channel.
waitEntry, notifyCh := waiter.NewChannelEntry(nil)
wq.EventRegister(&waitEntry, waiter.EventIn)
@@ -316,27 +328,26 @@ func (c *Conn) Read(b []byte) (int, error) {
deadline := c.readCancel()
- // Check if deadline has already expired.
- select {
- case <-deadline:
- return 0, c.newOpError("read", &timeoutError{})
- default:
- }
-
- if len(c.read) == 0 {
- var err error
- c.read, err = commonRead(c.ep, c.wq, deadline, nil, c)
- if err != nil {
- return 0, err
+ numRead := 0
+ for numRead != len(b) {
+ if len(c.read) == 0 {
+ var err error
+ c.read, err = commonRead(c.ep, c.wq, deadline, nil, c, numRead != 0)
+ if err != nil {
+ if numRead != 0 {
+ return numRead, nil
+ }
+ return numRead, err
+ }
+ }
+ n := copy(b[numRead:], c.read)
+ c.read.TrimFront(n)
+ numRead += n
+ if len(c.read) == 0 {
+ c.read = nil
}
}
-
- n := copy(b, c.read)
- c.read.TrimFront(n)
- if len(c.read) == 0 {
- c.read = nil
- }
- return n, nil
+ return numRead, nil
}
// Write implements net.Conn.Write.
@@ -550,15 +561,8 @@ func (c *PacketConn) newRemoteOpError(op string, remote net.Addr, err error) *ne
func (c *PacketConn) ReadFrom(b []byte) (int, net.Addr, error) {
deadline := c.readCancel()
- // Check if deadline has already expired.
- select {
- case <-deadline:
- return 0, nil, c.newOpError("read", &timeoutError{})
- default:
- }
-
var addr tcpip.FullAddress
- read, err := commonRead(c.ep, c.wq, deadline, &addr, c)
+ read, err := commonRead(c.ep, c.wq, deadline, &addr, c, false)
if err != nil {
return 0, nil, err
}