summaryrefslogtreecommitdiffhomepage
path: root/pkg/safemem/io_test.go
diff options
context:
space:
mode:
authorAdin Scannell <ascannell@google.com>2020-01-27 15:17:58 -0800
committergVisor bot <gvisor-bot@google.com>2020-01-27 15:31:32 -0800
commit0e2f1b7abd219f39d67cc2cecd00c441a13eeb29 (patch)
treee2199fc3ce8eb441dac9904d3df9d01cacf7edd7 /pkg/safemem/io_test.go
parent60d7ff73e1a788749d76fc9402836ae3a04053a8 (diff)
Update package locations.
Because the abi will depend on the core types for marshalling (usermem, context, safemem, safecopy), these need to be flattened from the sentry directory. These packages contain no sentry-specific details. PiperOrigin-RevId: 291811289
Diffstat (limited to 'pkg/safemem/io_test.go')
-rw-r--r--pkg/safemem/io_test.go199
1 files changed, 199 insertions, 0 deletions
diff --git a/pkg/safemem/io_test.go b/pkg/safemem/io_test.go
new file mode 100644
index 000000000..629741bee
--- /dev/null
+++ b/pkg/safemem/io_test.go
@@ -0,0 +1,199 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package safemem
+
+import (
+ "bytes"
+ "io"
+ "testing"
+)
+
+func makeBlocks(slices ...[]byte) []Block {
+ blocks := make([]Block, 0, len(slices))
+ for _, s := range slices {
+ blocks = append(blocks, BlockFromSafeSlice(s))
+ }
+ return blocks
+}
+
+func TestFromIOReaderFullRead(t *testing.T) {
+ r := FromIOReader{bytes.NewBufferString("foobar")}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+type eofHidingReader struct {
+ Reader io.Reader
+}
+
+func (r eofHidingReader) Read(dst []byte) (int, error) {
+ n, err := r.Reader.Read(dst)
+ if err == io.EOF {
+ return n, nil
+ }
+ return n, err
+}
+
+func TestFromIOReaderPartialRead(t *testing.T) {
+ r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ // FromIOReader should stop after the eofHidingReader returns (1, nil)
+ // for a 3-byte read.
+ if wantN := uint64(4); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+type singleByteReader struct {
+ Reader io.Reader
+}
+
+func (r singleByteReader) Read(dst []byte) (int, error) {
+ if len(dst) == 0 {
+ return r.Reader.Read(dst)
+ }
+ return r.Reader.Read(dst[:1])
+}
+
+func TestSingleByteReader(t *testing.T) {
+ r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts))
+ // FromIOReader should stop after the singleByteReader returns (1, nil)
+ // for a 3-byte read.
+ if wantN := uint64(1); n != wantN || err != nil {
+ t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+func TestReadFullToBlocks(t *testing.T) {
+ r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}}
+ dsts := makeBlocks(make([]byte, 3), make([]byte, 3))
+ n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts))
+ // ReadFullToBlocks should call into FromIOReader => singleByteReader
+ // repeatedly until dsts is exhausted.
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ for i, want := range [][]byte{[]byte("foo"), []byte("bar")} {
+ if got := dsts[i].ToSlice(); !bytes.Equal(got, want) {
+ t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want)
+ }
+ }
+}
+
+func TestFromIOWriterFullWrite(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{&dst}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+type limitedWriter struct {
+ Writer io.Writer
+ Done int
+ Limit int
+}
+
+func (w *limitedWriter) Write(src []byte) (int, error) {
+ count := len(src)
+ if count > (w.Limit - w.Done) {
+ count = w.Limit - w.Done
+ }
+ n, err := w.Writer.Write(src[:count])
+ w.Done += n
+ return n, err
+}
+
+func TestFromIOWriterPartialWrite(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{&limitedWriter{&dst, 0, 4}}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ // FromIOWriter should stop after the limitedWriter returns (1, nil) for a
+ // 3-byte write.
+ if wantN := uint64(4); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+type singleByteWriter struct {
+ Writer io.Writer
+}
+
+func (w singleByteWriter) Write(src []byte) (int, error) {
+ if len(src) == 0 {
+ return w.Writer.Write(src)
+ }
+ return w.Writer.Write(src[:1])
+}
+
+func TestSingleByteWriter(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{singleByteWriter{&dst}}
+ n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs))
+ // FromIOWriter should stop after the singleByteWriter returns (1, nil)
+ // for a 3-byte write.
+ if wantN := uint64(1); n != wantN || err != nil {
+ t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}
+
+func TestWriteFullToBlocks(t *testing.T) {
+ srcs := makeBlocks([]byte("foo"), []byte("bar"))
+ var dst bytes.Buffer
+ w := FromIOWriter{singleByteWriter{&dst}}
+ n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs))
+ // WriteFullToBlocks should call into FromIOWriter => singleByteWriter
+ // repeatedly until srcs is exhausted.
+ if wantN := uint64(6); n != wantN || err != nil {
+ t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN)
+ }
+ if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) {
+ t.Errorf("dst: got %q, wanted %q", got, want)
+ }
+}