diff options
Diffstat (limited to 'pkg/sentry/kernel/pipe/pipe_util.go')
-rw-r--r-- | pkg/sentry/kernel/pipe/pipe_util.go | 99 |
1 files changed, 41 insertions, 58 deletions
diff --git a/pkg/sentry/kernel/pipe/pipe_util.go b/pkg/sentry/kernel/pipe/pipe_util.go index f665920cb..77246edbe 100644 --- a/pkg/sentry/kernel/pipe/pipe_util.go +++ b/pkg/sentry/kernel/pipe/pipe_util.go @@ -21,9 +21,9 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/amutex" - "gvisor.dev/gvisor/pkg/buffer" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/marshal/primitive" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" @@ -44,46 +44,37 @@ func (p *Pipe) Release(context.Context) { // Read reads from the Pipe into dst. func (p *Pipe) Read(ctx context.Context, dst usermem.IOSequence) (int64, error) { - n, err := p.read(ctx, readOps{ - left: func() int64 { - return dst.NumBytes() - }, - limit: func(l int64) { - dst = dst.TakeFirst64(l) - }, - read: func(view *buffer.View) (int64, error) { - n, err := dst.CopyOutFrom(ctx, view) - dst = dst.DropFirst64(n) - view.TrimFront(n) - return n, err - }, - }) + n, err := dst.CopyOutFrom(ctx, p) if n > 0 { p.Notify(waiter.EventOut) } return n, err } +// ReadToBlocks implements safemem.Reader.ReadToBlocks for Pipe.Read. +func (p *Pipe) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { + n, err := p.read(int64(dsts.NumBytes()), func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }, true /* removeFromSrc */) + return uint64(n), err +} + +func (p *Pipe) read(count int64, f func(srcs safemem.BlockSeq) (uint64, error), removeFromSrc bool) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + n, err := p.peekLocked(count, f) + if n > 0 && removeFromSrc { + p.consumeLocked(n) + } + return n, err +} + // WriteTo writes to w from the Pipe. func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) (int64, error) { - ops := readOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - read: func(view *buffer.View) (int64, error) { - n, err := view.ReadToWriter(w, count) - if !dup { - view.TrimFront(n) - } - count -= n - return n, err - }, - } - n, err := p.read(ctx, ops) - if n > 0 { + n, err := p.read(count, func(srcs safemem.BlockSeq) (uint64, error) { + return safemem.FromIOWriter{w}.WriteFromBlocks(srcs) + }, !dup /* removeFromSrc */) + if n > 0 && !dup { p.Notify(waiter.EventOut) } return n, err @@ -91,39 +82,31 @@ func (p *Pipe) WriteTo(ctx context.Context, w io.Writer, count int64, dup bool) // Write writes to the Pipe from src. func (p *Pipe) Write(ctx context.Context, src usermem.IOSequence) (int64, error) { - n, err := p.write(ctx, writeOps{ - left: func() int64 { - return src.NumBytes() - }, - limit: func(l int64) { - src = src.TakeFirst64(l) - }, - write: func(view *buffer.View) (int64, error) { - n, err := src.CopyInTo(ctx, view) - src = src.DropFirst64(n) - return n, err - }, - }) + n, err := src.CopyInTo(ctx, p) if n > 0 { p.Notify(waiter.EventIn) } return n, err } +// WriteFromBlocks implements safemem.Writer.WriteFromBlocks for Pipe.Write. +func (p *Pipe) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { + n, err := p.write(int64(srcs.NumBytes()), func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.CopySeq(dsts, srcs) + }) + return uint64(n), err +} + +func (p *Pipe) write(count int64, f func(safemem.BlockSeq) (uint64, error)) (int64, error) { + p.mu.Lock() + defer p.mu.Unlock() + return p.writeLocked(count, f) +} + // ReadFrom reads from r to the Pipe. func (p *Pipe) ReadFrom(ctx context.Context, r io.Reader, count int64) (int64, error) { - n, err := p.write(ctx, writeOps{ - left: func() int64 { - return count - }, - limit: func(l int64) { - count = l - }, - write: func(view *buffer.View) (int64, error) { - n, err := view.WriteFromReader(r, count) - count -= n - return n, err - }, + n, err := p.write(count, func(dsts safemem.BlockSeq) (uint64, error) { + return safemem.FromIOReader{r}.ReadToBlocks(dsts) }) if n > 0 { p.Notify(waiter.EventIn) |