summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rw-r--r--pkg/sentry/socket/netstack/netstack.go17
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go4
-rw-r--r--pkg/tcpip/buffer/view.go14
-rw-r--r--pkg/tcpip/tcpip.go32
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go4
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go4
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go4
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go6
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go4
9 files changed, 57 insertions, 32 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 03749a8bf..22e128b96 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -425,8 +425,13 @@ func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Write
s.readMu.Lock()
defer s.readMu.Unlock()
+ w := tcpip.LimitedWriter{
+ W: dst,
+ N: count,
+ }
+
// This may return a blocking error.
- res, err := s.Endpoint.Read(dst, int(count), tcpip.ReadOptions{
+ res, err := s.Endpoint.Read(&w, tcpip.ReadOptions{
Peek: dup,
})
if err != nil {
@@ -2579,7 +2584,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// caller-supplied buffer.
var w io.Writer
if !isPacket && trunc {
- w = ioutil.Discard
+ w = &tcpip.LimitedWriter{
+ W: ioutil.Discard,
+ N: dst.NumBytes(),
+ }
} else {
w = dst.Writer(ctx)
}
@@ -2587,7 +2595,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
s.readMu.Lock()
defer s.readMu.Unlock()
- res, err := s.Endpoint.Read(w, int(dst.NumBytes()), readOptions)
+ res, err := s.Endpoint.Read(w, readOptions)
+ if err == tcpip.ErrBadBuffer && dst.NumBytes() == 0 {
+ err = nil
+ }
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index 85a0b8b90..fdeec12d3 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -295,7 +295,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
w := tcpip.SliceWriter(b)
opts := tcpip.ReadOptions{NeedRemoteAddr: addr != nil}
- res, err := ep.Read(&w, len(b), opts)
+ res, err := ep.Read(&w, opts)
if err == tcpip.ErrWouldBlock {
// Create wait queue entry that notifies a channel.
@@ -303,7 +303,7 @@ func commonRead(b []byte, ep tcpip.Endpoint, wq *waiter.Queue, deadline <-chan s
wq.EventRegister(&waitEntry, waiter.EventIn)
defer wq.EventUnregister(&waitEntry)
for {
- res, err = ep.Read(&w, len(b), opts)
+ res, err = ep.Read(&w, opts)
if err != tcpip.ErrWouldBlock {
break
}
diff --git a/pkg/tcpip/buffer/view.go b/pkg/tcpip/buffer/view.go
index 09d3dac66..91cc62cc8 100644
--- a/pkg/tcpip/buffer/view.go
+++ b/pkg/tcpip/buffer/view.go
@@ -148,23 +148,13 @@ func (vv *VectorisedView) ReadToVV(dstVV *VectorisedView, count int) (copied int
// ReadTo reads up to count bytes from vv to dst. It also removes them from vv
// unless peek is true.
-func (vv *VectorisedView) ReadTo(dst io.Writer, count int, peek bool) (int, error) {
+func (vv *VectorisedView) ReadTo(dst io.Writer, peek bool) (int, error) {
var err error
done := 0
for _, v := range vv.Views() {
- remaining := count - done
- if remaining <= 0 {
- break
- }
- if len(v) > remaining {
- v = v[:remaining]
- }
-
var n int
n, err = dst.Write(v)
- if n > 0 {
- done += n
- }
+ done += n
if err != nil {
break
}
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 49d4912ad..56aac093c 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -505,10 +505,34 @@ type SliceWriter []byte
func (s *SliceWriter) Write(b []byte) (int, error) {
n := copy(*s, b)
*s = (*s)[n:]
- if n < len(b) {
- return n, io.ErrShortWrite
+ var err error
+ if n != len(b) {
+ err = io.ErrShortWrite
}
- return n, nil
+ return n, err
+}
+
+var _ io.Writer = (*LimitedWriter)(nil)
+
+// A LimitedWriter writes to W but limits the amount of data copied to just N
+// bytes. Each call to Write updates N to reflect the new amount remaining.
+type LimitedWriter struct {
+ W io.Writer
+ N int64
+}
+
+func (l *LimitedWriter) Write(p []byte) (int, error) {
+ pLen := int64(len(p))
+ if pLen > l.N {
+ p = p[:l.N]
+ }
+ n, err := l.W.Write(p)
+ n64 := int64(n)
+ if err == nil && n64 != pLen {
+ err = io.ErrShortWrite
+ }
+ l.N -= n64
+ return n, err
}
// A ControlMessages contains socket control messages for IP sockets.
@@ -623,7 +647,7 @@ type Endpoint interface {
// If non-zero number of bytes are successfully read and written to dst, err
// must be nil. Otherwise, if dst failed to write anything, ErrBadBuffer
// should be returned.
- Read(dst io.Writer, count int, opts ReadOptions) (res ReadResult, err *Error)
+ Read(dst io.Writer, opts ReadOptions) (res ReadResult, err *Error)
// Write writes data to the endpoint's peer. This method does not block if
// the data cannot be written.
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index 87277fbd3..256e19296 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -154,7 +154,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
if e.rcvList.Empty() {
@@ -186,7 +186,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
res.RemoteAddr = p.senderAddress
}
- n, err := p.data.ReadTo(dst, count, opts.Peek)
+ n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index c3b3b8d34..c0d6fb442 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -162,7 +162,7 @@ func (ep *endpoint) Close() {
func (ep *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (ep *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
ep.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -199,7 +199,7 @@ func (ep *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpi
res.LinkPacketInfo = packet.packetInfo
}
- n, err := packet.data.ReadTo(dst, count, opts.Peek)
+ n, err := packet.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index 425bcf3ee..ae743f75e 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -191,7 +191,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvMu.Lock()
// If there's no data to read, return that read would block or that the
@@ -225,7 +225,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
res.RemoteAddr = pkt.senderAddr
}
- n, err := pkt.data.ReadTo(dst, count, opts.Peek)
+ n, err := pkt.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index a4508e871..ea509ac73 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1328,7 +1328,7 @@ func (e *endpoint) UpdateLastError(err *tcpip.Error) {
}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
e.rcvReadMu.Lock()
defer e.rcvReadMu.Unlock()
@@ -1346,9 +1346,9 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
var err error
done := 0
s := first
- for s != nil && done < count {
+ for s != nil {
var n int
- n, err = s.data.ReadTo(dst, count-done, opts.Peek)
+ n, err = s.data.ReadTo(dst, opts.Peek)
// Book keeping first then error handling.
done += n
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 520a0ac9d..9f9b3d510 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -284,7 +284,7 @@ func (e *endpoint) Close() {
func (e *endpoint) ModerateRecvBuf(copied int) {}
// Read implements tcpip.Endpoint.Read.
-func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
+func (e *endpoint) Read(dst io.Writer, opts tcpip.ReadOptions) (tcpip.ReadResult, *tcpip.Error) {
if err := e.LastError(); err != nil {
return tcpip.ReadResult{}, err
}
@@ -340,7 +340,7 @@ func (e *endpoint) Read(dst io.Writer, count int, opts tcpip.ReadOptions) (tcpip
res.RemoteAddr = p.senderAddress
}
- n, err := p.data.ReadTo(dst, count, opts.Peek)
+ n, err := p.data.ReadTo(dst, opts.Peek)
if n == 0 && err != nil {
return res, tcpip.ErrBadBuffer
}