summaryrefslogtreecommitdiffhomepage
path: root/pkg/usermem
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/usermem')
-rw-r--r--pkg/usermem/usermem.go20
1 files changed, 13 insertions, 7 deletions
diff --git a/pkg/usermem/usermem.go b/pkg/usermem/usermem.go
index 79db8895b..dc2571154 100644
--- a/pkg/usermem/usermem.go
+++ b/pkg/usermem/usermem.go
@@ -517,28 +517,29 @@ func (s IOSequence) CopyInTo(ctx context.Context, dst safemem.Writer) (int64, er
// Reader returns an io.Reader that reads from s. Reads beyond the end of s
// return io.EOF. The preconditions that apply to s.CopyIn also apply to the
// returned io.Reader.Read.
-func (s IOSequence) Reader(ctx context.Context) io.Reader {
- return &ioSequenceReadWriter{ctx, s}
+func (s IOSequence) Reader(ctx context.Context) *IOSequenceReadWriter {
+ return &IOSequenceReadWriter{ctx, s}
}
// Writer returns an io.Writer that writes to s. Writes beyond the end of s
// return ErrEndOfIOSequence. The preconditions that apply to s.CopyOut also
// apply to the returned io.Writer.Write.
-func (s IOSequence) Writer(ctx context.Context) io.Writer {
- return &ioSequenceReadWriter{ctx, s}
+func (s IOSequence) Writer(ctx context.Context) *IOSequenceReadWriter {
+ return &IOSequenceReadWriter{ctx, s}
}
// ErrEndOfIOSequence is returned by IOSequence.Writer().Write() when
// attempting to write beyond the end of the IOSequence.
var ErrEndOfIOSequence = errors.New("write beyond end of IOSequence")
-type ioSequenceReadWriter struct {
+// IOSequenceReadWriter implements io.Reader and io.Writer for an IOSequence.
+type IOSequenceReadWriter struct {
ctx context.Context
s IOSequence
}
// Read implements io.Reader.Read.
-func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
+func (rw *IOSequenceReadWriter) Read(dst []byte) (int, error) {
n, err := rw.s.CopyIn(rw.ctx, dst)
rw.s = rw.s.DropFirst(n)
if err == nil && rw.s.NumBytes() == 0 {
@@ -547,8 +548,13 @@ func (rw *ioSequenceReadWriter) Read(dst []byte) (int, error) {
return n, err
}
+// Len implements tcpip.Payloader.
+func (rw *IOSequenceReadWriter) Len() int {
+ return int(rw.s.NumBytes())
+}
+
// Write implements io.Writer.Write.
-func (rw *ioSequenceReadWriter) Write(src []byte) (int, error) {
+func (rw *IOSequenceReadWriter) Write(src []byte) (int, error) {
n, err := rw.s.CopyOut(rw.ctx, src)
rw.s = rw.s.DropFirst(n)
if err == nil && n < len(src) {