summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/safemem/io_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/safemem/io_test.go')
-rw-r--r--pkg/sentry/safemem/io_test.go199
1 files changed, 199 insertions, 0 deletions
diff --git a/pkg/sentry/safemem/io_test.go b/pkg/sentry/safemem/io_test.go
new file mode 100644
index 000000000..edac4c1d7
--- /dev/null
+++ b/pkg/sentry/safemem/io_test.go
@@ -0,0 +1,199 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "bytes"
+ "io"
+ "testing"
+)
+
+func makeBlocks(slices ...[]byte) []Block {
+ blocks := make([]Block, 0, len(slices))
+ for _, s := range slices {
+ blocks = append(blocks, BlockFromSafeSlice(s))
+ }
+ return blocks
+}
+
+func TestFromIOReaderFullRead(t *testing.T) {
+ r := FromIOReader{bytes.NewBufferString("foobar")}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+type eofHidingReader struct {
+ Reader io.Reader
+}
+
+func (r eofHidingReader) Read(dst []byte) (int, error) {
+ n, err := r.Reader.Read(dst)
+ if err == io.EOF {
+ return n, nil
+ }
+ return n, err
+}
+
+func TestFromIOReaderPartialRead(t *testing.T) {
+ r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ // FromIOReader should stop after the eofHidingReader returns (1, nil)
+ // for a 3-byte read.
+ if wantN := uint64(4); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+type singleByteReader struct {
+ Reader io.Reader
+}
+
+func (r singleByteReader) Read(dst []byte) (int, error) {
+ if len(dst) == 0 {
+ return r.Reader.Read(dst)
+ }
+ return r.Reader.Read(dst[:1])
+}
+
+func TestSingleByteReader(t *testing.T) {
+ r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ // FromIOReader should stop after the singleByteReader returns (1, nil)
+ // for a 3-byte read.
+ if wantN := uint64(1); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+func TestReadFullToBlocks(t *testing.T) {
+ r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts))
+ // ReadFullToBlocks should call into FromIOReader => singleByteReader
+ // repeatedly until dsts is exhausted.
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+func TestFromIOWriterFullWrite(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{&dst}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+type limitedWriter struct {
+ Writer io.Writer
+ Done int
+ Limit int
+}
+
+func (w *limitedWriter) Write(src []byte) (int, error) {
+ count := len(src)
+ if count > (w.Limit - w.Done) {
+ count = w.Limit - w.Done
+ }
+ n, err := w.Writer.Write(src[:count])
+ w.Done += n
+ return n, err
+}
+
+func TestFromIOWriterPartialWrite(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{&limitedWriter{&dst, 0, 4}}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ // FromIOWriter should stop after the limitedWriter returns (1, nil) for a
+ // 3-byte write.
+ if wantN := uint64(4); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+type singleByteWriter struct {
+ Writer io.Writer
+}
+
+func (w singleByteWriter) Write(src []byte) (int, error) {
+ if len(src) == 0 {
+ return w.Writer.Write(src)
+ }
+ return w.Writer.Write(src[:1])
+}
+
+func TestSingleByteWriter(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{singleByteWriter{&dst}}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ // FromIOWriter should stop after the singleByteWriter returns (1, nil)
+ // for a 3-byte write.
+ if wantN := uint64(1); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+func TestWriteFullToBlocks(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{singleByteWriter{&dst}}
+ n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs))
+ // WriteFullToBlocks should call into FromIOWriter => singleByteWriter
+ // repeatedly until srcs is exhausted.
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}