Diffstat (limited to 'pkg')
-rw-r--r--   pkg/sentry/socket/epsocket/epsocket.go   27
-rw-r--r--   pkg/tcpip/transport/tcp/endpoint.go      85
2 files changed, 73 insertions, 39 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 8cb5c823f..0f2cd05fc 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -429,6 +429,11 @@ func (i *ioSequencePayload) Size() int {
 	return int(i.src.NumBytes())
 }
 
+// DropFirst drops the first n bytes from underlying src.
+func (i *ioSequencePayload) DropFirst(n int) {
+	i.src = i.src.DropFirst(int(n))
+}
+
 // Write implements fs.FileOperations.Write.
 func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
 	f := &ioSequencePayload{ctx: ctx, src: src}
@@ -2026,28 +2031,22 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
 		addr = &addrBuf
 	}
 
-	v := buffer.NewView(int(src.NumBytes()))
-
-	// Copy all the data into the buffer.
-	if _, err := src.CopyIn(t, v); err != nil {
-		return 0, syserr.FromError(err)
-	}
-
 	opts := tcpip.WriteOptions{
 		To:          addr,
 		More:        flags&linux.MSG_MORE != 0,
 		EndOfRecord: flags&linux.MSG_EOR != 0,
 	}
 
-	n, resCh, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+	v := &ioSequencePayload{t, src}
+	n, resCh, err := s.Endpoint.Write(v, opts)
 	if resCh != nil {
 		if err := t.Block(resCh); err != nil {
 			return 0, syserr.FromError(err)
 		}
-		n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+		n, _, err = s.Endpoint.Write(v, opts)
 	}
 	dontWait := flags&linux.MSG_DONTWAIT != 0
-	if err == nil && (n >= uintptr(len(v)) || dontWait) {
+	if err == nil && (n >= uintptr(v.Size()) || dontWait) {
 		// Complete write.
 		return int(n), nil
 	}
@@ -2061,18 +2060,18 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
 	s.EventRegister(&e, waiter.EventOut)
 	defer s.EventUnregister(&e)
 
-	v.TrimFront(int(n))
+	v.DropFirst(int(n))
 	total := n
 	for {
-		n, _, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
-		v.TrimFront(int(n))
+		n, _, err = s.Endpoint.Write(v, opts)
+		v.DropFirst(int(n))
 		total += n
 
 		if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
 			return 0, syserr.TranslateNetstackError(err)
 		}
 
-		if err == nil && len(v) == 0 || err != nil && err != tcpip.ErrWouldBlock {
+		if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
 			return int(total), nil
 		}
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index e67169111..7c42a830a 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -878,60 +878,95 @@ func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) {
 	return v, nil
 }
 
-// Write writes data to the endpoint's peer.
-func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
-	// Linux completely ignores any address passed to sendto(2) for TCP sockets
-	// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
-	// and opts.EndOfRecord are also ignored.
-
-	e.mu.RLock()
-	defer e.mu.RUnlock()
-
+// isEndpointWritableLocked checks if a given endpoint is writable
+// and also returns the number of bytes that can be written at this
+// moment. If the endpoint is not writable then it returns an error
+// indicating the reason why it's not writable.
+// Caller must hold e.mu and e.sndBufMu
+func (e *endpoint) isEndpointWritableLocked() (int, *tcpip.Error) {
 	// The endpoint cannot be written to if it's not connected.
 	if !e.state.connected() {
 		switch e.state {
 		case StateError:
-			return 0, nil, e.hardError
+			return 0, e.hardError
 		default:
-			return 0, nil, tcpip.ErrClosedForSend
+			return 0, tcpip.ErrClosedForSend
 		}
 	}
 
-	// Nothing to do if the buffer is empty.
-	if p.Size() == 0 {
-		return 0, nil, nil
-	}
-
-	e.sndBufMu.Lock()
-
 	// Check if the connection has already been closed for sends.
 	if e.sndClosed {
-		e.sndBufMu.Unlock()
-		return 0, nil, tcpip.ErrClosedForSend
+		return 0, tcpip.ErrClosedForSend
 	}
 
-	// Check against the limit.
 	avail := e.sndBufSize - e.sndBufUsed
 	if avail <= 0 {
+		return 0, tcpip.ErrWouldBlock
+	}
+	return avail, nil
+}
+
+// Write writes data to the endpoint's peer.
+func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-chan struct{}, *tcpip.Error) {
+	// Linux completely ignores any address passed to sendto(2) for TCP sockets
+	// (without the MSG_FASTOPEN flag). Corking is unimplemented, so opts.More
+	// and opts.EndOfRecord are also ignored.
+
+	e.mu.RLock()
+	e.sndBufMu.Lock()
+
+	avail, err := e.isEndpointWritableLocked()
+	if err != nil {
 		e.sndBufMu.Unlock()
-		return 0, nil, tcpip.ErrWouldBlock
+		e.mu.RUnlock()
+		return 0, nil, err
+	}
+
+	e.sndBufMu.Unlock()
+	e.mu.RUnlock()
+
+	// Nothing to do if the buffer is empty.
+	if p.Size() == 0 {
+		return 0, nil, nil
 	}
 
+	// Copy in memory without holding sndBufMu so that worker goroutine can
+	// make progress independent of this operation.
 	v, perr := p.Get(avail)
 	if perr != nil {
-		e.sndBufMu.Unlock()
 		return 0, nil, perr
 	}
 
-	l := len(v)
-	s := newSegmentFromView(&e.route, e.id, v)
+	e.mu.RLock()
+	e.sndBufMu.Lock()
+
+	// Because we released the lock before copying, check state again
+	// to make sure the endpoint is still in a valid state for a
+	// write.
+	avail, err = e.isEndpointWritableLocked()
+	if err != nil {
+		e.sndBufMu.Unlock()
+		e.mu.RUnlock()
+		return 0, nil, err
+	}
+
+	// Discard any excess data copied in due to avail being reduced due to a
+	// simultaneous write call to the socket.
+	if avail < len(v) {
+		v = v[:avail]
+	}
 
 	// Add data to the send queue.
+	l := len(v)
+	s := newSegmentFromView(&e.route, e.id, v)
 	e.sndBufUsed += l
 	e.sndBufInQueue += seqnum.Size(l)
 	e.sndQueue.PushBack(s)
 	e.sndBufMu.Unlock()
+	// Release the endpoint lock to prevent deadlocks due to lock
+	// order inversion when acquiring workMu.
+	e.mu.RUnlock()
 
 	if e.workMu.TryLock() {
 		// Do the work inline.
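
Note on the epsocket.go change: instead of copying the whole IOSequence into a buffer.View up front, SendMsg now hands the endpoint an ioSequencePayload and, on a partial write, drops the bytes the endpoint already accepted before retrying. The plain-Go sketch below models only that retry shape; the payload, fakeEndpoint, and sendAll names are hypothetical stand-ins for illustration, not gVisor APIs.

package main

import (
	"errors"
	"fmt"
)

// payload is a hypothetical stand-in for the payload handed to
// Endpoint.Write: it reports its remaining size and can drop bytes the
// endpoint already accepted, like ioSequencePayload.DropFirst in the diff.
type payload struct {
	data []byte
}

func (p *payload) Size() int       { return len(p.data) }
func (p *payload) DropFirst(n int) { p.data = p.data[n:] }

var errWouldBlock = errors.New("would block")

// fakeEndpoint accepts at most 4 bytes per Write call, which forces the
// caller into the partial-write retry loop used by SendMsg after the patch.
type fakeEndpoint struct{ accepted []byte }

func (e *fakeEndpoint) Write(p *payload) (int, error) {
	if p.Size() == 0 {
		return 0, nil
	}
	n := p.Size()
	if n > 4 {
		n = 4
	}
	e.accepted = append(e.accepted, p.data[:n]...)
	return n, nil
}

// sendAll mirrors the shape of the SendMsg loop: write, drop the bytes the
// endpoint took, and keep going until the payload is drained or an error
// other than errWouldBlock is returned.
func sendAll(e *fakeEndpoint, p *payload) (int, error) {
	total := 0
	for {
		n, err := e.Write(p)
		p.DropFirst(n)
		total += n
		if err != nil && !errors.Is(err, errWouldBlock) {
			return total, err
		}
		if err == nil && p.Size() == 0 {
			return total, nil
		}
		// The real code blocks on the endpoint's notification channel
		// here before retrying.
	}
}

func main() {
	e := &fakeEndpoint{}
	n, err := sendAll(e, &payload{data: []byte("hello, netstack")})
	fmt.Println(n, err, string(e.accepted))
}

The benefit of this shape is that nothing is copied until the endpoint says how much it can take, so a large message no longer requires allocating and filling a full-size intermediate buffer before the first byte is queued.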
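Note on the endpoint.go change: the hunk follows a check/unlock/copy/relock/re-check pattern. Writability is checked while holding e.mu and e.sndBufMu, both locks are dropped for the potentially slow copy from user memory, and the check is repeated (with a trim to the re-read avail) before the segment is queued. A minimal sketch of that pattern, assuming a simplified hypothetical sendBuffer type rather than gVisor's endpoint:

package main

import (
	"errors"
	"fmt"
	"sync"
)

// sendBuffer is a hypothetical stand-in for the endpoint's send queue: it
// tracks bytes queued against a fixed limit, guarded by mu (playing the
// role of sndBufMu in the real endpoint).
type sendBuffer struct {
	mu    sync.Mutex
	used  int
	limit int
	queue [][]byte
}

var errWouldBlock = errors.New("would block")

// writableLocked reports how many bytes can be accepted right now.
// Caller must hold b.mu, mirroring isEndpointWritableLocked.
func (b *sendBuffer) writableLocked() (int, error) {
	avail := b.limit - b.used
	if avail <= 0 {
		return 0, errWouldBlock
	}
	return avail, nil
}

// write copies data in using the same shape as the patched endpoint.Write:
// check under the lock, drop the lock for the copy, then re-take the lock,
// re-check, and trim to whatever space is still available before queueing.
func (b *sendBuffer) write(src []byte) (int, error) {
	b.mu.Lock()
	avail, err := b.writableLocked()
	if err != nil {
		b.mu.Unlock()
		return 0, err
	}
	b.mu.Unlock()

	// Copy without holding the lock so other writers and readers can
	// make progress during the copy.
	if avail > len(src) {
		avail = len(src)
	}
	v := make([]byte, avail)
	copy(v, src[:avail])

	b.mu.Lock()
	defer b.mu.Unlock()

	// The buffer may have filled while we were copying; re-check and
	// discard any excess, like the avail < len(v) trim in the patch.
	avail, err = b.writableLocked()
	if err != nil {
		return 0, err
	}
	if avail < len(v) {
		v = v[:avail]
	}

	b.used += len(v)
	b.queue = append(b.queue, v)
	return len(v), nil
}

func main() {
	b := &sendBuffer{limit: 8}
	n, err := b.write([]byte("hello, world"))
	fmt.Println(n, err) // accepts at most 8 bytes; the caller retries the rest
}

The second writableLocked call is what makes the early unlock safe: a concurrent writer may have consumed the space while the copy ran, so the already-copied data is trimmed to whatever room is actually left before it is queued.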