From 6c0e1d9cfe6adbfbb32e7020d6426608ac63ad37 Mon Sep 17 00:00:00 2001 From: Tamir Duberstein Date: Fri, 22 Jan 2021 12:24:20 -0800 Subject: Define tcpip.Payloader in terms of io.Reader Fixes #1509. PiperOrigin-RevId: 353295589 --- pkg/usermem/usermem.go | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) (limited to 'pkg/usermem') diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go index 79db8895b..dc2571154 100644 --- a/pkg/usermem/usermem.go +++ b/pkg/usermem/usermem.go @@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er // Reader returns an io.Reader that reads from s. Reads beyond the end of s // return io.EOF. The preconditions that apply to s.CopyIn also apply to the // returned io.Reader.Read. -func (s IOSequence) Reader(ctx context.Context) io.Reader { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // Writer returns an io.Writer that writes to s. Writes beyond the end of s // return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also // apply to the returned io.Writer.Write. -func (s IOSequence) Writer(ctx context.Context) io.Writer { - return &ioSequenceReadWriter{ctx, s} +func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter { + return &IOSequenceReadWriter{ctx, s} } // ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when // attempting to write beyond the end of the IOSequence. var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence") -type ioSequenceReadWriter struct { +// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence. +type IOSequenceReadWriter struct { ctx context.Context s IOSequence } // Read implements io.Reader.Read. -func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { +func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) { n, err := rw.s.CopyIn(rw.ctx, dst) rw.s = rw.s.DropFirst(n) if err == nil && rw.s.NumBytes() == 0 { @@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) { return n, err } +// Len implements tcpip.Payloader. +func (rw *IOSequenceReadWriter) Len() int { + return int(rw.s.NumBytes()) +} + // Write implements io.Writer.Write. -func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) { +func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) { n, err := rw.s.CopyOut(rw.ctx, src) rw.s = rw.s.DropFirst(n) if err == nil && n < len(src) { -- cgit v1.2.3