diff options
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 17 | ||||
-rw-r--r-- | pkg/tcpip/adapters/gonet/gonet.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/buffer/view.go | 14 | ||||
-rw-r--r-- | pkg/tcpip/tcpip.go | 32 | ||||
-rw-r--r-- | pkg/tcpip/transport/icmp/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/packet/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/raw/endpoint.go | 4 | ||||
-rw-r--r-- | pkg/tcpip/transport/tcp/endpoint.go | 6 | ||||
-rw-r--r-- | pkg/tcpip/transport/udp/endpoint.go | 4 |
9 files changed, 57 insertions, 32 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 03749a8bf..22e128b96 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -425,8 +425,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write s.readMu.Lock() defer s.readMu.Unlock() + w := tcpip.LimitedWriter{ + W: dst, + N: count, + } + // This may return a blocking error. - res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{ + res, err := s.Endpoint.Read(&w, tcpip.ReadOptions{ Peek: dup, }) if err != nil { @@ -2579,7 +2584,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // caller-supplied buffer. var w io.Writer if !isPacket && trunc { - w = ioutil.Discard + w = &tcpip.LimitedWriter{ + W: ioutil.Discard, + N: dst.NumBytes(), + } } else { w = dst.Writer(ctx) } @@ -2587,7 +2595,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq s.readMu.Lock() defer s.readMu.Unlock() - res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions) + res, err := s.Endpoint.Read(w, readOptions) + if err == tcpip.ErrBadBuffer && dst.NumBytes() == 0 { + err = nil + } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err) } diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index 85a0b8b90..fdeec12d3 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -295,7 +295,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s w := tcpip.SliceWriter(b) opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil} - res, err := ep.Read(&w, len(b), opts) + res, err := ep.Read(&w, opts) if err == tcpip.ErrWouldBlock { // Create wait queue entry that notifies a channel. @@ -303,7 +303,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s wq.EventRegister(&waitEntry, waiter.EventIn) defer wq.EventUnregister(&waitEntry) for { - res, err = ep.Read(&w, len(b), opts) + res, err = ep.Read(&w, opts) if err != tcpip.ErrWouldBlock { break } diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go index 09d3dac66..91cc62cc8 100644 --- a/pkg/tcpip/buffer/view.go +++ b/pkg/tcpip/buffer/view.go @@ -148,23 +148,13 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int // ReadTo reads up to count bytes from vv to dst. It also removes them from vv // unless peek is true. -func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) { +func (vv *VectorisedView) ReadTo(dst io.Writer, peek bool) (int, error) { var err error done := 0 for _, v := range vv.Views() { - remaining := count - done - if remaining <= 0 { - break - } - if len(v) > remaining { - v = v[:remaining] - } - var n int n, err = dst.Write(v) - if n > 0 { - done += n - } + done += n if err != nil { break } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 49d4912ad..56aac093c 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -505,10 +505,34 @@ type SliceWriter []byte func (s *SliceWriter) Write(b []byte) (int, error) { n := copy(*s, b) *s = (*s)[n:] - if n < len(b) { - return n, io.ErrShortWrite + var err error + if n != len(b) { + err = io.ErrShortWrite } - return n, nil + return n, err +} + +var _ io.Writer = (*LimitedWriter)(nil) + +// A LimitedWriter writes to W but limits the amount of data copied to just N +// bytes. Each call to Write updates N to reflect the new amount remaining. +type LimitedWriter struct { + W io.Writer + N int64 +} + +func (l *LimitedWriter) Write(p []byte) (int, error) { + pLen := int64(len(p)) + if pLen > l.N { + p = p[:l.N] + } + n, err := l.W.Write(p) + n64 := int64(n) + if err == nil && n64 != pLen { + err = io.ErrShortWrite + } + l.N -= n64 + return n, err } // A ControlMessages contains socket control messages for IP sockets. @@ -623,7 +647,7 @@ type Endpoint interface { // If non-zero number of bytes are successfully read and written to dst, err // must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer // should be returned. - Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error) + Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error) // Write writes data to the endpoint's peer. This method does not block if // the data cannot be written. diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 87277fbd3..256e19296 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -154,7 +154,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() if e.rcvList.Empty() { @@ -186,7 +186,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip res.RemoteAddr = p.senderAddress } - n, err := p.data.ReadTo(dst, count, opts.Peek) + n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index c3b3b8d34..c0d6fb442 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -162,7 +162,7 @@ func (ep *endpoint) Close() { func (ep *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { ep.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -199,7 +199,7 @@ func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpi res.LinkPacketInfo = packet.packetInfo } - n, err := packet.data.ReadTo(dst, count, opts.Peek) + n, err := packet.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 425bcf3ee..ae743f75e 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -191,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvMu.Lock() // If there's no data to read, return that read would block or that the @@ -225,7 +225,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip res.RemoteAddr = pkt.senderAddr } - n, err := pkt.data.ReadTo(dst, count, opts.Peek) + n, err := pkt.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index a4508e871..ea509ac73 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) { } // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { e.rcvReadMu.Lock() defer e.rcvReadMu.Unlock() @@ -1346,9 +1346,9 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip var err error done := 0 s := first - for s != nil && done < count { + for s != nil { var n int - n, err = s.data.ReadTo(dst, count-done, opts.Peek) + n, err = s.data.ReadTo(dst, opts.Peek) // Book keeping first then error handling. done += n diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 520a0ac9d..9f9b3d510 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -284,7 +284,7 @@ func (e *endpoint) Close() { func (e *endpoint) ModerateRecvBuf(copied int) {} // Read implements tcpip.Endpoint.Read. -func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { +func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) { if err := e.LastError(); err != nil { return tcpip.ReadResult{}, err } @@ -340,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip res.RemoteAddr = p.senderAddress } - n, err := p.data.ReadTo(dst, count, opts.Peek) + n, err := p.data.ReadTo(dst, opts.Peek) if n == 0 && err != nil { return res, tcpip.ErrBadBuffer } |