diff options
Diffstat (limited to 'pkg/secio')
-rw-r--r-- | pkg/secio/BUILD | 19 | ||||
-rw-r--r-- | pkg/secio/full_reader.go | 34 | ||||
-rw-r--r-- | pkg/secio/secio.go | 105 | ||||
-rw-r--r-- | pkg/secio/secio_test.go | 126 |
4 files changed, 284 insertions, 0 deletions
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD new file mode 100644 index 000000000..60f63c7a6 --- /dev/null +++ b/pkg/secio/BUILD @@ -0,0 +1,19 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "secio", + srcs = [ + "full_reader.go", + "secio.go", + ], + visibility = ["//pkg/sentry:internal"], +) + +go_test( + name = "secio_test", + size = "small", + srcs = ["secio_test.go"], + library = ":secio", +) diff --git a/pkg/secio/full_reader.go b/pkg/secio/full_reader.go new file mode 100644 index 000000000..aed2564bd --- /dev/null +++ b/pkg/secio/full_reader.go @@ -0,0 +1,34 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secio + +import ( + "io" +) + +// FullReader adapts an io.Reader to never return partial reads with a nil +// error. +type FullReader struct { + Reader io.Reader +} + +// Read implements io.Reader.Read. +func (r FullReader) Read(dst []byte) (int, error) { + n, err := io.ReadFull(r.Reader, dst) + if err == io.ErrUnexpectedEOF { + return n, io.EOF + } + return n, err +} diff --git a/pkg/secio/secio.go b/pkg/secio/secio.go new file mode 100644 index 000000000..b43226035 --- /dev/null +++ b/pkg/secio/secio.go @@ -0,0 +1,105 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package secio provides support for sectioned I/O. +package secio + +import ( + "errors" + "io" +) + +// ErrReachedLimit is returned when SectionReader.Read or SectionWriter.Write +// reaches its limit. +var ErrReachedLimit = errors.New("reached limit") + +// SectionReader implements io.Reader on a section of an underlying io.ReaderAt. +// It is similar to io.SectionReader, but: +// +// - Reading beyond the limit returns ErrReachedLimit, not io.EOF. +// +// - Limit overflow is handled correctly. +type SectionReader struct { + r io.ReaderAt + off int64 + limit int64 +} + +// Read implements io.Reader.Read. +func (r *SectionReader) Read(dst []byte) (int, error) { + if r.limit >= 0 { + if max := r.limit - r.off; max < int64(len(dst)) { + dst = dst[:max] + } + } + n, err := r.r.ReadAt(dst, r.off) + r.off += int64(n) + if err == nil && r.off == r.limit { + err = ErrReachedLimit + } + return n, err +} + +// NewOffsetReader returns an io.Reader that reads from r starting at offset +// off. +func NewOffsetReader(r io.ReaderAt, off int64) *SectionReader { + return &SectionReader{r, off, -1} +} + +// NewSectionReader returns an io.Reader that reads from r starting at offset +// off and stops with ErrReachedLimit after n bytes. +func NewSectionReader(r io.ReaderAt, off int64, n int64) *SectionReader { + // If off + n overflows, it will be < 0 such that no limit applies, but + // this is the correct behavior as long as r prohibits reading at offsets + // beyond MaxInt64. + return &SectionReader{r, off, off + n} +} + +// SectionWriter implements io.Writer on a section of an underlying +// io.WriterAt. Writing beyond the limit returns ErrReachedLimit. +type SectionWriter struct { + w io.WriterAt + off int64 + limit int64 +} + +// Write implements io.Writer.Write. +func (w *SectionWriter) Write(src []byte) (int, error) { + if w.limit >= 0 { + if max := w.limit - w.off; max < int64(len(src)) { + src = src[:max] + } + } + n, err := w.w.WriteAt(src, w.off) + w.off += int64(n) + if err == nil && w.off == w.limit { + err = ErrReachedLimit + } + return n, err +} + +// NewOffsetWriter returns an io.Writer that writes to w starting at offset +// off. +func NewOffsetWriter(w io.WriterAt, off int64) *SectionWriter { + return &SectionWriter{w, off, -1} +} + +// NewSectionWriter returns an io.Writer that writes to w starting at offset +// off and stops with ErrReachedLimit after n bytes. +func NewSectionWriter(w io.WriterAt, off int64, n int64) *SectionWriter { + // If off + n overflows, it will be < 0 such that no limit applies, but + // this is the correct behavior as long as w prohibits writing at offsets + // beyond MaxInt64. + return &SectionWriter{w, off, off + n} +} diff --git a/pkg/secio/secio_test.go b/pkg/secio/secio_test.go new file mode 100644 index 000000000..d1d905187 --- /dev/null +++ b/pkg/secio/secio_test.go @@ -0,0 +1,126 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package secio + +import ( + "bytes" + "errors" + "io" + "io/ioutil" + "math" + "testing" +) + +var errEndOfBuffer = errors.New("write beyond end of buffer") + +// buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt. +// Reads beyond the end of the buffer return io.EOF. Writes beyond the end of +// the buffer return errEndOfBuffer. +type buffer struct { + Bytes []byte +} + +// ReadAt implements io.ReaderAt.ReadAt. +func (b *buffer) ReadAt(dst []byte, off int64) (int, error) { + if off >= int64(len(b.Bytes)) { + return 0, io.EOF + } + n := copy(dst, b.Bytes[off:]) + if n < len(dst) { + return n, io.EOF + } + return n, nil +} + +// WriteAt implements io.WriterAt.WriteAt. +func (b *buffer) WriteAt(src []byte, off int64) (int, error) { + if off >= int64(len(b.Bytes)) { + return 0, errEndOfBuffer + } + n := copy(b.Bytes[off:], src) + if n < len(src) { + return n, errEndOfBuffer + } + return n, nil +} + +func newBufferString(s string) *buffer { + return &buffer{[]byte(s)} +} + +func TestOffsetReader(t *testing.T) { + buf := newBufferString("foobar") + r := NewOffsetReader(buf, 3) + dst, err := ioutil.ReadAll(r) + if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { + t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) + } +} + +func TestSectionReader(t *testing.T) { + buf := newBufferString("foobarbaz") + r := NewSectionReader(buf, 3, 3) + dst, err := ioutil.ReadAll(r) + if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr { + t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr) + } +} + +func TestSectionReaderLimitOverflow(t *testing.T) { + // SectionReader behaves like OffsetReader when limit overflows int64. + buf := newBufferString("foobar") + r := NewSectionReader(buf, 3, math.MaxInt64) + dst, err := ioutil.ReadAll(r) + if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil { + t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want) + } +} + +func TestOffsetWriter(t *testing.T) { + buf := newBufferString("ABCDEF") + w := NewOffsetWriter(buf, 3) + n, err := w.Write([]byte("foobar")) + if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { + t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { + t.Errorf("buf.Bytes: got %q, wanted %q", got, want) + } +} + +func TestSectionWriter(t *testing.T) { + buf := newBufferString("ABCDEFGHI") + w := NewSectionWriter(buf, 3, 3) + n, err := w.Write([]byte("foobar")) + if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr { + t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) { + t.Errorf("buf.Bytes: got %q, wanted %q", got, want) + } +} + +func TestSectionWriterLimitOverflow(t *testing.T) { + // SectionWriter behaves like OffsetWriter when limit overflows int64. + buf := newBufferString("ABCDEF") + w := NewSectionWriter(buf, 3, math.MaxInt64) + n, err := w.Write([]byte("foobar")) + if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr { + t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr) + } + if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) { + t.Errorf("buf.Bytes: got %q, wanted %q", got, want) + } +} |