summaryrefslogtreecommitdiffhomepage
path: root/pkg/secio
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/secio')
-rw-r--r--pkg/secio/BUILD19
-rw-r--r--pkg/secio/full_reader.go34
-rw-r--r--pkg/secio/secio.go105
-rw-r--r--pkg/secio/secio_test.go126
4 files changed, 284 insertions, 0 deletions
diff --git a/pkg/secio/BUILD b/pkg/secio/BUILD
new file mode 100644
index 000000000..60f63c7a6
--- /dev/null
+++ b/pkg/secio/BUILD
@@ -0,0 +1,19 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "secio",
+ srcs = [
+ "full_reader.go",
+ "secio.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+)
+
+go_test(
+ name = "secio_test",
+ size = "small",
+ srcs = ["secio_test.go"],
+ library = ":secio",
+)
diff --git a/pkg/secio/full_reader.go b/pkg/secio/full_reader.go
new file mode 100644
index 000000000..aed2564bd
--- /dev/null
+++ b/pkg/secio/full_reader.go
@@ -0,0 +1,34 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package secio
+
+import (
+ "io"
+)
+
+// FullReader adapts an io.Reader to never return partial reads with a nil
+// error.
+type FullReader struct {
+ Reader io.Reader
+}
+
+// Read implements io.Reader.Read.
+func (r FullReader) Read(dst []byte) (int, error) {
+ n, err := io.ReadFull(r.Reader, dst)
+ if err == io.ErrUnexpectedEOF {
+ return n, io.EOF
+ }
+ return n, err
+}
diff --git a/pkg/secio/secio.go b/pkg/secio/secio.go
new file mode 100644
index 000000000..b43226035
--- /dev/null
+++ b/pkg/secio/secio.go
@@ -0,0 +1,105 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package secio provides support for sectioned I/O.
+package secio
+
+import (
+ "errors"
+ "io"
+)
+
+// ErrReachedLimit is returned when SectionReader.Read or SectionWriter.Write
+// reaches its limit.
+var ErrReachedLimit = errors.New("reached limit")
+
+// SectionReader implements io.Reader on a section of an underlying io.ReaderAt.
+// It is similar to io.SectionReader, but:
+//
+// - Reading beyond the limit returns ErrReachedLimit, not io.EOF.
+//
+// - Limit overflow is handled correctly.
+type SectionReader struct {
+ r io.ReaderAt
+ off int64
+ limit int64
+}
+
+// Read implements io.Reader.Read.
+func (r *SectionReader) Read(dst []byte) (int, error) {
+ if r.limit >= 0 {
+ if max := r.limit - r.off; max < int64(len(dst)) {
+ dst = dst[:max]
+ }
+ }
+ n, err := r.r.ReadAt(dst, r.off)
+ r.off += int64(n)
+ if err == nil && r.off == r.limit {
+ err = ErrReachedLimit
+ }
+ return n, err
+}
+
+// NewOffsetReader returns an io.Reader that reads from r starting at offset
+// off.
+func NewOffsetReader(r io.ReaderAt, off int64) *SectionReader {
+ return &SectionReader{r, off, -1}
+}
+
+// NewSectionReader returns an io.Reader that reads from r starting at offset
+// off and stops with ErrReachedLimit after n bytes.
+func NewSectionReader(r io.ReaderAt, off int64, n int64) *SectionReader {
+ // If off + n overflows, it will be < 0 such that no limit applies, but
+ // this is the correct behavior as long as r prohibits reading at offsets
+ // beyond MaxInt64.
+ return &SectionReader{r, off, off + n}
+}
+
+// SectionWriter implements io.Writer on a section of an underlying
+// io.WriterAt. Writing beyond the limit returns ErrReachedLimit.
+type SectionWriter struct {
+ w io.WriterAt
+ off int64
+ limit int64
+}
+
+// Write implements io.Writer.Write.
+func (w *SectionWriter) Write(src []byte) (int, error) {
+ if w.limit >= 0 {
+ if max := w.limit - w.off; max < int64(len(src)) {
+ src = src[:max]
+ }
+ }
+ n, err := w.w.WriteAt(src, w.off)
+ w.off += int64(n)
+ if err == nil && w.off == w.limit {
+ err = ErrReachedLimit
+ }
+ return n, err
+}
+
+// NewOffsetWriter returns an io.Writer that writes to w starting at offset
+// off.
+func NewOffsetWriter(w io.WriterAt, off int64) *SectionWriter {
+ return &SectionWriter{w, off, -1}
+}
+
+// NewSectionWriter returns an io.Writer that writes to w starting at offset
+// off and stops with ErrReachedLimit after n bytes.
+func NewSectionWriter(w io.WriterAt, off int64, n int64) *SectionWriter {
+ // If off + n overflows, it will be < 0 such that no limit applies, but
+ // this is the correct behavior as long as w prohibits writing at offsets
+ // beyond MaxInt64.
+ return &SectionWriter{w, off, off + n}
+}
diff --git a/pkg/secio/secio_test.go b/pkg/secio/secio_test.go
new file mode 100644
index 000000000..d1d905187
--- /dev/null
+++ b/pkg/secio/secio_test.go
@@ -0,0 +1,126 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package secio
+
+import (
+ "bytes"
+ "errors"
+ "io"
+ "io/ioutil"
+ "math"
+ "testing"
+)
+
+var errEndOfBuffer = errors.New("write beyond end of buffer")
+
+// buffer resembles bytes.Buffer, but implements io.ReaderAt and io.WriterAt.
+// Reads beyond the end of the buffer return io.EOF. Writes beyond the end of
+// the buffer return errEndOfBuffer.
+type buffer struct {
+ Bytes []byte
+}
+
+// ReadAt implements io.ReaderAt.ReadAt.
+func (b *buffer) ReadAt(dst []byte, off int64) (int, error) {
+ if off >= int64(len(b.Bytes)) {
+ return 0, io.EOF
+ }
+ n := copy(dst, b.Bytes[off:])
+ if n < len(dst) {
+ return n, io.EOF
+ }
+ return n, nil
+}
+
+// WriteAt implements io.WriterAt.WriteAt.
+func (b *buffer) WriteAt(src []byte, off int64) (int, error) {
+ if off >= int64(len(b.Bytes)) {
+ return 0, errEndOfBuffer
+ }
+ n := copy(b.Bytes[off:], src)
+ if n < len(src) {
+ return n, errEndOfBuffer
+ }
+ return n, nil
+}
+
+func newBufferString(s string) *buffer {
+ return &buffer{[]byte(s)}
+}
+
+func TestOffsetReader(t *testing.T) {
+ buf := newBufferString("foobar")
+ r := NewOffsetReader(buf, 3)
+ dst, err := ioutil.ReadAll(r)
+ if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil {
+ t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want)
+ }
+}
+
+func TestSectionReader(t *testing.T) {
+ buf := newBufferString("foobarbaz")
+ r := NewSectionReader(buf, 3, 3)
+ dst, err := ioutil.ReadAll(r)
+ if want, wantErr := []byte("bar"), ErrReachedLimit; !bytes.Equal(dst, want) || err != wantErr {
+ t.Errorf("ReadAll: got (%q, %v), wanted (%q, %v)", dst, err, want, wantErr)
+ }
+}
+
+func TestSectionReaderLimitOverflow(t *testing.T) {
+ // SectionReader behaves like OffsetReader when limit overflows int64.
+ buf := newBufferString("foobar")
+ r := NewSectionReader(buf, 3, math.MaxInt64)
+ dst, err := ioutil.ReadAll(r)
+ if want := []byte("bar"); !bytes.Equal(dst, want) || err != nil {
+ t.Errorf("ReadAll: got (%q, %v), wanted (%q, nil)", dst, err, want)
+ }
+}
+
+func TestOffsetWriter(t *testing.T) {
+ buf := newBufferString("ABCDEF")
+ w := NewOffsetWriter(buf, 3)
+ n, err := w.Write([]byte("foobar"))
+ if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr {
+ t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) {
+ t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestSectionWriter(t *testing.T) {
+ buf := newBufferString("ABCDEFGHI")
+ w := NewSectionWriter(buf, 3, 3)
+ n, err := w.Write([]byte("foobar"))
+ if wantN, wantErr := 3, ErrReachedLimit; n != wantN || err != wantErr {
+ t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := buf.Bytes, []byte("ABCfooGHI"); !bytes.Equal(got, want) {
+ t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
+ }
+}
+
+func TestSectionWriterLimitOverflow(t *testing.T) {
+ // SectionWriter behaves like OffsetWriter when limit overflows int64.
+ buf := newBufferString("ABCDEF")
+ w := NewSectionWriter(buf, 3, math.MaxInt64)
+ n, err := w.Write([]byte("foobar"))
+ if wantN, wantErr := 3, errEndOfBuffer; n != wantN || err != wantErr {
+ t.Errorf("WriteAt: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
+ }
+ if got, want := buf.Bytes, []byte("ABCfoo"); !bytes.Equal(got, want) {
+ t.Errorf("buf.Bytes: got %q, wanted %q", got, want)
+ }
+}