diff options
Diffstat (limited to 'pkg/sentry/safemem')
-rw-r--r-- | pkg/sentry/safemem/BUILD | 28 | ||||
-rw-r--r-- | pkg/sentry/safemem/block_unsafe.go | 269 | ||||
-rw-r--r-- | pkg/sentry/safemem/io.go | 339 | ||||
-rw-r--r-- | pkg/sentry/safemem/io_test.go | 199 | ||||
-rw-r--r-- | pkg/sentry/safemem/safemem.go | 16 | ||||
-rw-r--r-- | pkg/sentry/safemem/seq_test.go | 196 | ||||
-rw-r--r-- | pkg/sentry/safemem/seq_unsafe.go | 299 |
7 files changed, 1346 insertions, 0 deletions
diff --git a/pkg/sentry/safemem/BUILD b/pkg/sentry/safemem/BUILD new file mode 100644 index 000000000..dc4cfce41 --- /dev/null +++ b/pkg/sentry/safemem/BUILD @@ -0,0 +1,28 @@ +package(licenses = ["notice"]) # Apache 2.0 + +load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test") + +go_library( + name = "safemem", + srcs = [ + "block_unsafe.go", + "io.go", + "safemem.go", + "seq_unsafe.go", + ], + importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/safemem", + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/sentry/platform/safecopy", + ], +) + +go_test( + name = "safemem_test", + size = "small", + srcs = [ + "io_test.go", + "seq_test.go", + ], + embed = [":safemem"], +) diff --git a/pkg/sentry/safemem/block_unsafe.go b/pkg/sentry/safemem/block_unsafe.go new file mode 100644 index 000000000..0b58f6497 --- /dev/null +++ b/pkg/sentry/safemem/block_unsafe.go @@ -0,0 +1,269 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "fmt" + "reflect" + "unsafe" + + "gvisor.googlesource.com/gvisor/pkg/sentry/platform/safecopy" +) + +// A Block is a range of contiguous bytes, similar to []byte but with the +// following differences: +// +// - The memory represented by a Block may require the use of safecopy to +// access. +// +// - Block does not carry a capacity and cannot be expanded. +// +// Blocks are immutable and may be copied by value. The zero value of Block +// represents an empty range, analogous to a nil []byte. +type Block struct { + // [start, start+length) is the represented memory. + // + // start is an unsafe.Pointer to ensure that Block prevents the represented + // memory from being garbage-collected. + start unsafe.Pointer + length int + + // needSafecopy is true if accessing the represented memory requires the + // use of safecopy. + needSafecopy bool +} + +// BlockFromSafeSlice returns a Block equivalent to slice, which is safe to +// access without safecopy. +func BlockFromSafeSlice(slice []byte) Block { + return blockFromSlice(slice, false) +} + +// BlockFromUnsafeSlice returns a Block equivalent to bs, which is not safe to +// access without safecopy. +func BlockFromUnsafeSlice(slice []byte) Block { + return blockFromSlice(slice, true) +} + +func blockFromSlice(slice []byte, needSafecopy bool) Block { + if len(slice) == 0 { + return Block{} + } + return Block{ + start: unsafe.Pointer(&slice[0]), + length: len(slice), + needSafecopy: needSafecopy, + } +} + +// BlockFromSafePointer returns a Block equivalent to [ptr, ptr+len), which is +// safe to access without safecopy. +// +// Preconditions: ptr+len does not overflow. +func BlockFromSafePointer(ptr unsafe.Pointer, len int) Block { + return blockFromPointer(ptr, len, false) +} + +// BlockFromUnsafePointer returns a Block equivalent to [ptr, ptr+len), which +// is not safe to access without safecopy. +// +// Preconditions: ptr+len does not overflow. +func BlockFromUnsafePointer(ptr unsafe.Pointer, len int) Block { + return blockFromPointer(ptr, len, true) +} + +func blockFromPointer(ptr unsafe.Pointer, len int, needSafecopy bool) Block { + if uptr := uintptr(ptr); uptr+uintptr(len) < uptr { + panic(fmt.Sprintf("ptr %#x + len %#x overflows", ptr, len)) + } + return Block{ + start: ptr, + length: len, + needSafecopy: needSafecopy, + } +} + +// DropFirst returns a Block equivalent to b, but with the first n bytes +// omitted. It is analogous to the [n:] operation on a slice, except that if n +// > b.Len(), DropFirst returns an empty Block instead of panicking. +// +// Preconditions: n >= 0. +func (b Block) DropFirst(n int) Block { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return b.DropFirst64(uint64(n)) +} + +// DropFirst64 is equivalent to DropFirst but takes a uint64. +func (b Block) DropFirst64(n uint64) Block { + if n >= uint64(b.length) { + return Block{} + } + return Block{ + start: unsafe.Pointer(uintptr(b.start) + uintptr(n)), + length: b.length - int(n), + needSafecopy: b.needSafecopy, + } +} + +// TakeFirst returns a Block equivalent to the first n bytes of b. It is +// analogous to the [:n] operation on a slice, except that if n > b.Len(), +// TakeFirst returns a copy of b instead of panicking. +// +// Preconditions: n >= 0. +func (b Block) TakeFirst(n int) Block { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return b.TakeFirst64(uint64(n)) +} + +// TakeFirst64 is equivalent to TakeFirst but takes a uint64. +func (b Block) TakeFirst64(n uint64) Block { + if n == 0 { + return Block{} + } + if n >= uint64(b.length) { + return b + } + return Block{ + start: b.start, + length: int(n), + needSafecopy: b.needSafecopy, + } +} + +// ToSlice returns a []byte equivalent to b. +func (b Block) ToSlice() []byte { + var bs []byte + hdr := (*reflect.SliceHeader)(unsafe.Pointer(&bs)) + hdr.Data = uintptr(b.start) + hdr.Len = b.length + hdr.Cap = b.length + return bs +} + +// Addr returns b's start address as a uintptr. It returns uintptr instead of +// unsafe.Pointer so that code using safemem cannot obtain unsafe.Pointers +// without importing the unsafe package explicitly. +// +// Note that a uintptr is not recognized as a pointer by the garbage collector, +// such that if there are no uses of b after a call to b.Addr() and the address +// is to Go-managed memory, the returned uintptr does not prevent garbage +// collection of the pointee. +func (b Block) Addr() uintptr { + return uintptr(b.start) +} + +// Len returns b's length in bytes. +func (b Block) Len() int { + return b.length +} + +// NeedSafecopy returns true if accessing b.ToSlice() requires the use of safecopy. +func (b Block) NeedSafecopy() bool { + return b.needSafecopy +} + +// String implements fmt.Stringer.String. +func (b Block) String() string { + if uintptr(b.start) == 0 && b.length == 0 { + return "<nil>" + } + var suffix string + if b.needSafecopy { + suffix = "*" + } + return fmt.Sprintf("[%#x-%#x)%s", uintptr(b.start), uintptr(b.start)+uintptr(b.length), suffix) +} + +// Copy copies src.Len() or dst.Len() bytes, whichever is less, from src +// to dst and returns the number of bytes copied. +// +// If src and dst overlap, the data stored in dst is unspecified. +func Copy(dst, src Block) (int, error) { + if !dst.needSafecopy && !src.needSafecopy { + return copy(dst.ToSlice(), src.ToSlice()), nil + } + + n := dst.length + if n > src.length { + n = src.length + } + if n == 0 { + return 0, nil + } + + switch { + case dst.needSafecopy && !src.needSafecopy: + return safecopy.CopyOut(dst.start, src.TakeFirst(n).ToSlice()) + case !dst.needSafecopy && src.needSafecopy: + return safecopy.CopyIn(dst.TakeFirst(n).ToSlice(), src.start) + case dst.needSafecopy && src.needSafecopy: + n64, err := safecopy.Copy(dst.start, src.start, uintptr(n)) + return int(n64), err + default: + panic("unreachable") + } +} + +// Zero sets all bytes in dst to 0 and returns the number of bytes zeroed. +func Zero(dst Block) (int, error) { + if !dst.needSafecopy { + bs := dst.ToSlice() + for i := range bs { + bs[i] = 0 + } + return len(bs), nil + } + + n64, err := safecopy.ZeroOut(dst.start, uintptr(dst.length)) + return int(n64), err +} + +// Safecopy atomics are no slower than non-safecopy atomics, so use the former +// even when !b.needSafecopy to get consistent alignment checking. + +// SwapUint32 invokes safecopy.SwapUint32 on the first 4 bytes of b. +// +// Preconditions: b.Len() >= 4. +func SwapUint32(b Block, new uint32) (uint32, error) { + if b.length < 4 { + panic(fmt.Sprintf("insufficient length: %d", b.length)) + } + return safecopy.SwapUint32(b.start, new) +} + +// SwapUint64 invokes safecopy.SwapUint64 on the first 8 bytes of b. +// +// Preconditions: b.Len() >= 8. +func SwapUint64(b Block, new uint64) (uint64, error) { + if b.length < 8 { + panic(fmt.Sprintf("insufficient length: %d", b.length)) + } + return safecopy.SwapUint64(b.start, new) +} + +// CompareAndSwapUint32 invokes safecopy.CompareAndSwapUint32 on the first 4 +// bytes of b. +// +// Preconditions: b.Len() >= 4. +func CompareAndSwapUint32(b Block, old, new uint32) (uint32, error) { + if b.length < 4 { + panic(fmt.Sprintf("insufficient length: %d", b.length)) + } + return safecopy.CompareAndSwapUint32(b.start, old, new) +} diff --git a/pkg/sentry/safemem/io.go b/pkg/sentry/safemem/io.go new file mode 100644 index 000000000..fd917648b --- /dev/null +++ b/pkg/sentry/safemem/io.go @@ -0,0 +1,339 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "errors" + "io" + "math" +) + +// ErrEndOfBlockSeq is returned by BlockSeqWriter when attempting to write +// beyond the end of the BlockSeq. +var ErrEndOfBlockSeq = errors.New("write beyond end of BlockSeq") + +// Reader represents a streaming byte source like io.Reader. +type Reader interface { + // ReadToBlocks reads up to dsts.NumBytes() bytes into dsts and returns the + // number of bytes read. It may return a partial read without an error + // (i.e. (n, nil) where 0 < n < dsts.NumBytes()). It should not return a + // full read with an error (i.e. (dsts.NumBytes(), err) where err != nil); + // note that this differs from io.Reader.Read (in particular, io.EOF should + // not be returned if ReadToBlocks successfully reads dsts.NumBytes() + // bytes.) + ReadToBlocks(dsts BlockSeq) (uint64, error) +} + +// Writer represents a streaming byte sink like io.Writer. +type Writer interface { + // WriteFromBlocks writes up to srcs.NumBytes() bytes from srcs and returns + // the number of bytes written. It may return a partial write without an + // error (i.e. (n, nil) where 0 < n < srcs.NumBytes()). It should not + // return a full write with an error (i.e. srcs.NumBytes(), err) where err + // != nil). + WriteFromBlocks(srcs BlockSeq) (uint64, error) +} + +// ReadFullToBlocks repeatedly invokes r.ReadToBlocks until dsts.NumBytes() +// bytes have been read or ReadToBlocks returns an error. +func ReadFullToBlocks(r Reader, dsts BlockSeq) (uint64, error) { + var done uint64 + for !dsts.IsEmpty() { + n, err := r.ReadToBlocks(dsts) + done += n + if err != nil { + return done, err + } + dsts = dsts.DropFirst64(n) + } + return done, nil +} + +// WriteFullFromBlocks repeatedly invokes w.WriteFromBlocks until +// srcs.NumBytes() bytes have been written or WriteFromBlocks returns an error. +func WriteFullFromBlocks(w Writer, srcs BlockSeq) (uint64, error) { + var done uint64 + for !srcs.IsEmpty() { + n, err := w.WriteFromBlocks(srcs) + done += n + if err != nil { + return done, err + } + srcs = srcs.DropFirst64(n) + } + return done, nil +} + +// BlockSeqReader implements Reader by reading from a BlockSeq. +type BlockSeqReader struct { + Blocks BlockSeq +} + +// ReadToBlocks implements Reader.ReadToBlocks. +func (r *BlockSeqReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { + n, err := CopySeq(dsts, r.Blocks) + r.Blocks = r.Blocks.DropFirst64(n) + if err != nil { + return n, err + } + if n < dsts.NumBytes() { + return n, io.EOF + } + return n, nil +} + +// BlockSeqWriter implements Writer by writing to a BlockSeq. +type BlockSeqWriter struct { + Blocks BlockSeq +} + +// WriteFromBlocks implements Writer.WriteFromBlocks. +func (w *BlockSeqWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + n, err := CopySeq(w.Blocks, srcs) + w.Blocks = w.Blocks.DropFirst64(n) + if err != nil { + return n, err + } + if n < srcs.NumBytes() { + return n, ErrEndOfBlockSeq + } + return n, nil +} + +// ReaderFunc implements Reader for a function with the semantics of +// Reader.ReadToBlocks. +type ReaderFunc func(dsts BlockSeq) (uint64, error) + +// ReadToBlocks implements Reader.ReadToBlocks. +func (f ReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { + return f(dsts) +} + +// WriterFunc implements Writer for a function with the semantics of +// Writer.WriteFromBlocks. +type WriterFunc func(srcs BlockSeq) (uint64, error) + +// WriteFromBlocks implements Writer.WriteFromBlocks. +func (f WriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + return f(srcs) +} + +// ToIOReader implements io.Reader for a (safemem.)Reader. +// +// ToIOReader will return a successful partial read iff Reader.ReadToBlocks does +// so. +type ToIOReader struct { + Reader Reader +} + +// Read implements io.Reader.Read. +func (r ToIOReader) Read(dst []byte) (int, error) { + n, err := r.Reader.ReadToBlocks(BlockSeqOf(BlockFromSafeSlice(dst))) + return int(n), err +} + +// ToIOWriter implements io.Writer for a (safemem.)Writer. +type ToIOWriter struct { + Writer Writer +} + +// Write implements io.Writer.Write. +func (w ToIOWriter) Write(src []byte) (int, error) { + // io.Writer does not permit partial writes. + n, err := WriteFullFromBlocks(w.Writer, BlockSeqOf(BlockFromSafeSlice(src))) + return int(n), err +} + +// FromIOReader implements Reader for an io.Reader by repeatedly invoking +// io.Reader.Read until it returns an error or partial read. +// +// FromIOReader will return a successful partial read iff Reader.Read does so. +type FromIOReader struct { + Reader io.Reader +} + +// ReadToBlocks implements Reader.ReadToBlocks. +func (r FromIOReader) ReadToBlocks(dsts BlockSeq) (uint64, error) { + var buf []byte + var done uint64 + for !dsts.IsEmpty() { + dst := dsts.Head() + var n int + var err error + n, buf, err = r.readToBlock(dst, buf) + done += uint64(n) + if n != dst.Len() { + return done, err + } + dsts = dsts.Tail() + if err != nil { + if dsts.IsEmpty() && err == io.EOF { + return done, nil + } + return done, err + } + } + return done, nil +} + +func (r FromIOReader) readToBlock(dst Block, buf []byte) (int, []byte, error) { + // io.Reader isn't safecopy-aware, so we have to buffer Blocks that require + // safecopy. + if !dst.NeedSafecopy() { + n, err := r.Reader.Read(dst.ToSlice()) + return n, buf, err + } + if len(buf) < dst.Len() { + buf = make([]byte, dst.Len()) + } + rn, rerr := r.Reader.Read(buf[:dst.Len()]) + wbn, wberr := Copy(dst, BlockFromSafeSlice(buf[:rn])) + if wberr != nil { + return wbn, buf, wberr + } + return wbn, buf, rerr +} + +// FromIOWriter implements Writer for an io.Writer by repeatedly invoking +// io.Writer.Write until it returns an error or partial write. +// +// FromIOWriter will tolerate implementations of io.Writer.Write that return +// partial writes with a nil error in contravention of io.Writer's +// requirements, since Writer is permitted to do so. FromIOWriter will return a +// successful partial write iff Writer.Write does so. +type FromIOWriter struct { + Writer io.Writer +} + +// WriteFromBlocks implements Writer.WriteFromBlocks. +func (w FromIOWriter) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + var buf []byte + var done uint64 + for !srcs.IsEmpty() { + src := srcs.Head() + var n int + var err error + n, buf, err = w.writeFromBlock(src, buf) + done += uint64(n) + if n != src.Len() || err != nil { + return done, err + } + srcs = srcs.Tail() + } + return done, nil +} + +func (w FromIOWriter) writeFromBlock(src Block, buf []byte) (int, []byte, error) { + // io.Writer isn't safecopy-aware, so we have to buffer Blocks that require + // safecopy. + if !src.NeedSafecopy() { + n, err := w.Writer.Write(src.ToSlice()) + return n, buf, err + } + if len(buf) < src.Len() { + buf = make([]byte, src.Len()) + } + bufn, buferr := Copy(BlockFromSafeSlice(buf[:src.Len()]), src) + wn, werr := w.Writer.Write(buf[:bufn]) + if werr != nil { + return wn, buf, werr + } + return wn, buf, buferr +} + +// FromVecReaderFunc implements Reader for a function that reads data into a +// [][]byte and returns the number of bytes read as an int64. +type FromVecReaderFunc struct { + ReadVec func(dsts [][]byte) (int64, error) +} + +// ReadToBlocks implements Reader.ReadToBlocks. +// +// ReadToBlocks calls r.ReadVec at most once. +func (r FromVecReaderFunc) ReadToBlocks(dsts BlockSeq) (uint64, error) { + if dsts.IsEmpty() { + return 0, nil + } + // Ensure that we don't pass a [][]byte with a total length > MaxInt64. + dsts = dsts.TakeFirst64(uint64(math.MaxInt64)) + dstSlices := make([][]byte, 0, dsts.NumBlocks()) + // Buffer Blocks that require safecopy. + for tmp := dsts; !tmp.IsEmpty(); tmp = tmp.Tail() { + dst := tmp.Head() + if dst.NeedSafecopy() { + dstSlices = append(dstSlices, make([]byte, dst.Len())) + } else { + dstSlices = append(dstSlices, dst.ToSlice()) + } + } + rn, rerr := r.ReadVec(dstSlices) + dsts = dsts.TakeFirst64(uint64(rn)) + var done uint64 + var i int + for !dsts.IsEmpty() { + dst := dsts.Head() + if dst.NeedSafecopy() { + n, err := Copy(dst, BlockFromSafeSlice(dstSlices[i])) + done += uint64(n) + if err != nil { + return done, err + } + } else { + done += uint64(dst.Len()) + } + dsts = dsts.Tail() + i++ + } + return done, rerr +} + +// FromVecWriterFunc implements Writer for a function that writes data from a +// [][]byte and returns the number of bytes written. +type FromVecWriterFunc struct { + WriteVec func(srcs [][]byte) (int64, error) +} + +// WriteFromBlocks implements Writer.WriteFromBlocks. +// +// WriteFromBlocks calls w.WriteVec at most once. +func (w FromVecWriterFunc) WriteFromBlocks(srcs BlockSeq) (uint64, error) { + if srcs.IsEmpty() { + return 0, nil + } + // Ensure that we don't pass a [][]byte with a total length > MaxInt64. + srcs = srcs.TakeFirst64(uint64(math.MaxInt64)) + srcSlices := make([][]byte, 0, srcs.NumBlocks()) + // Buffer Blocks that require safecopy. + var buferr error + for tmp := srcs; !tmp.IsEmpty(); tmp = tmp.Tail() { + src := tmp.Head() + if src.NeedSafecopy() { + slice := make([]byte, src.Len()) + n, err := Copy(BlockFromSafeSlice(slice), src) + srcSlices = append(srcSlices, slice[:n]) + if err != nil { + buferr = err + break + } + } else { + srcSlices = append(srcSlices, src.ToSlice()) + } + } + n, err := w.WriteVec(srcSlices) + if err != nil { + return uint64(n), err + } + return uint64(n), buferr +} diff --git a/pkg/sentry/safemem/io_test.go b/pkg/sentry/safemem/io_test.go new file mode 100644 index 000000000..edac4c1d7 --- /dev/null +++ b/pkg/sentry/safemem/io_test.go @@ -0,0 +1,199 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "bytes" + "io" + "testing" +) + +func makeBlocks(slices ...[]byte) []Block { + blocks := make([]Block, 0, len(slices)) + for _, s := range slices { + blocks = append(blocks, BlockFromSafeSlice(s)) + } + return blocks +} + +func TestFromIOReaderFullRead(t *testing.T) { + r := FromIOReader{bytes.NewBufferString("foobar")} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +type eofHidingReader struct { + Reader io.Reader +} + +func (r eofHidingReader) Read(dst []byte) (int, error) { + n, err := r.Reader.Read(dst) + if err == io.EOF { + return n, nil + } + return n, err +} + +func TestFromIOReaderPartialRead(t *testing.T) { + r := FromIOReader{eofHidingReader{bytes.NewBufferString("foob")}} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) + // FromIOReader should stop after the eofHidingReader returns (1, nil) + // for a 3-byte read. + if wantN := uint64(4); n != wantN || err != nil { + t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("foo"), []byte("b\x00\x00")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +type singleByteReader struct { + Reader io.Reader +} + +func (r singleByteReader) Read(dst []byte) (int, error) { + if len(dst) == 0 { + return r.Reader.Read(dst) + } + return r.Reader.Read(dst[:1]) +} + +func TestSingleByteReader(t *testing.T) { + r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := r.ReadToBlocks(BlockSeqFromSlice(dsts)) + // FromIOReader should stop after the singleByteReader returns (1, nil) + // for a 3-byte read. + if wantN := uint64(1); n != wantN || err != nil { + t.Errorf("ReadToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("f\x00\x00"), []byte("\x00\x00\x00")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +func TestReadFullToBlocks(t *testing.T) { + r := FromIOReader{singleByteReader{bytes.NewBufferString("foobar")}} + dsts := makeBlocks(make([]byte, 3), make([]byte, 3)) + n, err := ReadFullToBlocks(r, BlockSeqFromSlice(dsts)) + // ReadFullToBlocks should call into FromIOReader => singleByteReader + // repeatedly until dsts is exhausted. + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("ReadFullToBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + for i, want := range [][]byte{[]byte("foo"), []byte("bar")} { + if got := dsts[i].ToSlice(); !bytes.Equal(got, want) { + t.Errorf("dsts[%d]: got %q, wanted %q", i, got, want) + } + } +} + +func TestFromIOWriterFullWrite(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{&dst} + n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +type limitedWriter struct { + Writer io.Writer + Done int + Limit int +} + +func (w *limitedWriter) Write(src []byte) (int, error) { + count := len(src) + if count > (w.Limit - w.Done) { + count = w.Limit - w.Done + } + n, err := w.Writer.Write(src[:count]) + w.Done += n + return n, err +} + +func TestFromIOWriterPartialWrite(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{&limitedWriter{&dst, 0, 4}} + n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) + // FromIOWriter should stop after the limitedWriter returns (1, nil) for a + // 3-byte write. + if wantN := uint64(4); n != wantN || err != nil { + t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +type singleByteWriter struct { + Writer io.Writer +} + +func (w singleByteWriter) Write(src []byte) (int, error) { + if len(src) == 0 { + return w.Writer.Write(src) + } + return w.Writer.Write(src[:1]) +} + +func TestSingleByteWriter(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{singleByteWriter{&dst}} + n, err := w.WriteFromBlocks(BlockSeqFromSlice(srcs)) + // FromIOWriter should stop after the singleByteWriter returns (1, nil) + // for a 3-byte write. + if wantN := uint64(1); n != wantN || err != nil { + t.Errorf("WriteFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("f"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} + +func TestWriteFullToBlocks(t *testing.T) { + srcs := makeBlocks([]byte("foo"), []byte("bar")) + var dst bytes.Buffer + w := FromIOWriter{singleByteWriter{&dst}} + n, err := WriteFullFromBlocks(w, BlockSeqFromSlice(srcs)) + // WriteFullToBlocks should call into FromIOWriter => singleByteWriter + // repeatedly until srcs is exhausted. + if wantN := uint64(6); n != wantN || err != nil { + t.Errorf("WriteFullFromBlocks: got (%v, %v), wanted (%v, nil)", n, err, wantN) + } + if got, want := dst.Bytes(), []byte("foobar"); !bytes.Equal(got, want) { + t.Errorf("dst: got %q, wanted %q", got, want) + } +} diff --git a/pkg/sentry/safemem/safemem.go b/pkg/sentry/safemem/safemem.go new file mode 100644 index 000000000..2f8002004 --- /dev/null +++ b/pkg/sentry/safemem/safemem.go @@ -0,0 +1,16 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package safemem provides the Block and BlockSeq types. +package safemem diff --git a/pkg/sentry/safemem/seq_test.go b/pkg/sentry/safemem/seq_test.go new file mode 100644 index 000000000..3e83b3851 --- /dev/null +++ b/pkg/sentry/safemem/seq_test.go @@ -0,0 +1,196 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "bytes" + "reflect" + "testing" +) + +type blockSeqTest struct { + desc string + + pieces []string + haveOffset bool + offset uint64 + haveLimit bool + limit uint64 + + want string +} + +func (t blockSeqTest) NonEmptyByteSlices() [][]byte { + // t is a value, so we can mutate it freely. + slices := make([][]byte, 0, len(t.pieces)) + for _, str := range t.pieces { + if t.haveOffset { + strOff := t.offset + if strOff > uint64(len(str)) { + strOff = uint64(len(str)) + } + str = str[strOff:] + t.offset -= strOff + } + if t.haveLimit { + strLim := t.limit + if strLim > uint64(len(str)) { + strLim = uint64(len(str)) + } + str = str[:strLim] + t.limit -= strLim + } + if len(str) != 0 { + slices = append(slices, []byte(str)) + } + } + return slices +} + +func (t blockSeqTest) BlockSeq() BlockSeq { + blocks := make([]Block, 0, len(t.pieces)) + for _, str := range t.pieces { + blocks = append(blocks, BlockFromSafeSlice([]byte(str))) + } + bs := BlockSeqFromSlice(blocks) + if t.haveOffset { + bs = bs.DropFirst64(t.offset) + } + if t.haveLimit { + bs = bs.TakeFirst64(t.limit) + } + return bs +} + +var blockSeqTests = []blockSeqTest{ + { + desc: "Empty sequence", + }, + { + desc: "Sequence of length 1", + pieces: []string{"foobar"}, + want: "foobar", + }, + { + desc: "Sequence of length 2", + pieces: []string{"foo", "bar"}, + want: "foobar", + }, + { + desc: "Empty Blocks", + pieces: []string{"", "foo", "", "", "bar", ""}, + want: "foobar", + }, + { + desc: "Sequence with non-zero offset", + pieces: []string{"foo", "bar"}, + haveOffset: true, + offset: 2, + want: "obar", + }, + { + desc: "Sequence with non-maximal limit", + pieces: []string{"foo", "bar"}, + haveLimit: true, + limit: 5, + want: "fooba", + }, + { + desc: "Sequence with offset and limit", + pieces: []string{"foo", "bar"}, + haveOffset: true, + offset: 2, + haveLimit: true, + limit: 3, + want: "oba", + }, +} + +func TestBlockSeqNumBytes(t *testing.T) { + for _, test := range blockSeqTests { + t.Run(test.desc, func(t *testing.T) { + if got, want := test.BlockSeq().NumBytes(), uint64(len(test.want)); got != want { + t.Errorf("NumBytes: got %d, wanted %d", got, want) + } + }) + } +} + +func TestBlockSeqIterBlocks(t *testing.T) { + // Tests BlockSeq iteration using Head/Tail. + for _, test := range blockSeqTests { + t.Run(test.desc, func(t *testing.T) { + srcs := test.BlockSeq() + // "Note that a non-nil empty slice and a nil slice ... are not + // deeply equal." - reflect + slices := make([][]byte, 0, 0) + for !srcs.IsEmpty() { + src := srcs.Head() + slices = append(slices, src.ToSlice()) + nextSrcs := srcs.Tail() + if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-uint64(src.Len()); got != want { + t.Fatalf("%v.Tail(): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) + } + srcs = nextSrcs + } + if wantSlices := test.NonEmptyByteSlices(); !reflect.DeepEqual(slices, wantSlices) { + t.Errorf("Accumulated slices: got %v, wanted %v", slices, wantSlices) + } + }) + } +} + +func TestBlockSeqIterBytes(t *testing.T) { + // Tests BlockSeq iteration using Head/DropFirst. + for _, test := range blockSeqTests { + t.Run(test.desc, func(t *testing.T) { + srcs := test.BlockSeq() + var dst bytes.Buffer + for !srcs.IsEmpty() { + src := srcs.Head() + var b [1]byte + n, err := Copy(BlockFromSafeSlice(b[:]), src) + if n != 1 || err != nil { + t.Fatalf("Copy: got (%v, %v), wanted (1, nil)", n, err) + } + dst.WriteByte(b[0]) + nextSrcs := srcs.DropFirst(1) + if got, want := nextSrcs.NumBytes(), srcs.NumBytes()-1; got != want { + t.Fatalf("%v.DropFirst(1): got %v (%d bytes), wanted %d bytes", srcs, nextSrcs, got, want) + } + srcs = nextSrcs + } + if got := string(dst.Bytes()); got != test.want { + t.Errorf("Copied string: got %q, wanted %q", got, test.want) + } + }) + } +} + +func TestBlockSeqDropBeyondLimit(t *testing.T) { + blocks := []Block{BlockFromSafeSlice([]byte("123")), BlockFromSafeSlice([]byte("4"))} + bs := BlockSeqFromSlice(blocks) + if got, want := bs.NumBytes(), uint64(4); got != want { + t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) + } + bs = bs.TakeFirst(1) + if got, want := bs.NumBytes(), uint64(1); got != want { + t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) + } + bs = bs.DropFirst(2) + if got, want := bs.NumBytes(), uint64(0); got != want { + t.Errorf("%v.NumBytes(): got %d, wanted %d", bs, got, want) + } +} diff --git a/pkg/sentry/safemem/seq_unsafe.go b/pkg/sentry/safemem/seq_unsafe.go new file mode 100644 index 000000000..e0d29a0b3 --- /dev/null +++ b/pkg/sentry/safemem/seq_unsafe.go @@ -0,0 +1,299 @@ +// Copyright 2018 Google Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package safemem + +import ( + "bytes" + "fmt" + "reflect" + "unsafe" +) + +// A BlockSeq represents a sequence of Blocks, each of which has non-zero +// length. +// +// BlockSeqs are immutable and may be copied by value. The zero value of +// BlockSeq represents an empty sequence. +type BlockSeq struct { + // If length is 0, then the BlockSeq is empty. Invariants: data == 0; + // offset == 0; limit == 0. + // + // If length is -1, then the BlockSeq represents the single Block{data, + // limit, false}. Invariants: offset == 0; limit > 0; limit does not + // overflow the range of an int. + // + // If length is -2, then the BlockSeq represents the single Block{data, + // limit, true}. Invariants: offset == 0; limit > 0; limit does not + // overflow the range of an int. + // + // Otherwise, length >= 2, and the BlockSeq represents the `length` Blocks + // in the array of Blocks starting at address `data`, starting at `offset` + // bytes into the first Block and limited to the following `limit` bytes. + // Invariants: data != 0; offset < len(data[0]); limit > 0; offset+limit <= + // the combined length of all Blocks in the array; the first Block in the + // array has non-zero length. + // + // length is never 1; sequences consisting of a single Block are always + // stored inline (with length < 0). + data unsafe.Pointer + length int + offset int + limit uint64 +} + +// BlockSeqOf returns a BlockSeq representing the single Block b. +func BlockSeqOf(b Block) BlockSeq { + bs := BlockSeq{ + data: b.start, + length: -1, + limit: uint64(b.length), + } + if b.needSafecopy { + bs.length = -2 + } + return bs +} + +// BlockSeqFromSlice returns a BlockSeq representing all Blocks in slice. +// If slice contains Blocks with zero length, BlockSeq will skip them during +// iteration. +// +// Whether the returned BlockSeq shares memory with slice is unspecified; +// clients should avoid mutating slices passed to BlockSeqFromSlice. +// +// Preconditions: The combined length of all Blocks in slice <= math.MaxUint64. +func BlockSeqFromSlice(slice []Block) BlockSeq { + slice = skipEmpty(slice) + var limit uint64 + for _, b := range slice { + sum := limit + uint64(b.Len()) + if sum < limit { + panic("BlockSeq length overflows uint64") + } + limit = sum + } + return blockSeqFromSliceLimited(slice, limit) +} + +// Preconditions: The combined length of all Blocks in slice <= limit. If +// len(slice) != 0, the first Block in slice has non-zero length, and limit > +// 0. +func blockSeqFromSliceLimited(slice []Block, limit uint64) BlockSeq { + switch len(slice) { + case 0: + return BlockSeq{} + case 1: + return BlockSeqOf(slice[0].TakeFirst64(limit)) + default: + return BlockSeq{ + data: unsafe.Pointer(&slice[0]), + length: len(slice), + limit: limit, + } + } +} + +func skipEmpty(slice []Block) []Block { + for i, b := range slice { + if b.Len() != 0 { + return slice[i:] + } + } + return nil +} + +// IsEmpty returns true if bs contains no Blocks. +// +// Invariants: bs.IsEmpty() == (bs.NumBlocks() == 0) == (bs.NumBytes() == 0). +// (Of these, prefer to use bs.IsEmpty().) +func (bs BlockSeq) IsEmpty() bool { + return bs.length == 0 +} + +// NumBlocks returns the number of Blocks in bs. +func (bs BlockSeq) NumBlocks() int { + // In general, we have to count: if bs represents a windowed slice then the + // slice may contain Blocks with zero length, and bs.length may be larger + // than the actual number of Blocks due to bs.limit. + var n int + for !bs.IsEmpty() { + n++ + bs = bs.Tail() + } + return n +} + +// NumBytes returns the sum of Block.Len() for all Blocks in bs. +func (bs BlockSeq) NumBytes() uint64 { + return bs.limit +} + +// Head returns the first Block in bs. +// +// Preconditions: !bs.IsEmpty(). +func (bs BlockSeq) Head() Block { + if bs.length == 0 { + panic("empty BlockSeq") + } + if bs.length < 0 { + return bs.internalBlock() + } + return (*Block)(bs.data).DropFirst(bs.offset).TakeFirst64(bs.limit) +} + +// Preconditions: bs.length < 0. +func (bs BlockSeq) internalBlock() Block { + return Block{ + start: bs.data, + length: int(bs.limit), + needSafecopy: bs.length == -2, + } +} + +// Tail returns a BlockSeq consisting of all Blocks in bs after the first. +// +// Preconditions: !bs.IsEmpty(). +func (bs BlockSeq) Tail() BlockSeq { + if bs.length == 0 { + panic("empty BlockSeq") + } + if bs.length < 0 { + return BlockSeq{} + } + head := (*Block)(bs.data).DropFirst(bs.offset) + headLen := uint64(head.Len()) + if headLen >= bs.limit { + // The head Block exhausts the limit, so the tail is empty. + return BlockSeq{} + } + var extSlice []Block + extSliceHdr := (*reflect.SliceHeader)(unsafe.Pointer(&extSlice)) + extSliceHdr.Data = uintptr(bs.data) + extSliceHdr.Len = bs.length + extSliceHdr.Cap = bs.length + tailSlice := skipEmpty(extSlice[1:]) + tailLimit := bs.limit - headLen + return blockSeqFromSliceLimited(tailSlice, tailLimit) +} + +// DropFirst returns a BlockSeq equivalent to bs, but with the first n bytes +// omitted. If n > bs.NumBytes(), DropFirst returns an empty BlockSeq. +// +// Preconditions: n >= 0. +func (bs BlockSeq) DropFirst(n int) BlockSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return bs.DropFirst64(uint64(n)) +} + +// DropFirst64 is equivalent to DropFirst but takes an uint64. +func (bs BlockSeq) DropFirst64(n uint64) BlockSeq { + if n >= bs.limit { + return BlockSeq{} + } + for { + // Calling bs.Head() here is surprisingly expensive, so inline getting + // the head's length. + var headLen uint64 + if bs.length < 0 { + headLen = bs.limit + } else { + headLen = uint64((*Block)(bs.data).Len() - bs.offset) + } + if n < headLen { + // Dropping ends partway through the head Block. + if bs.length < 0 { + return BlockSeqOf(bs.internalBlock().DropFirst64(n)) + } + bs.offset += int(n) + bs.limit -= n + return bs + } + n -= headLen + bs = bs.Tail() + } +} + +// TakeFirst returns a BlockSeq equivalent to the first n bytes of bs. If n > +// bs.NumBytes(), TakeFirst returns a BlockSeq equivalent to bs. +// +// Preconditions: n >= 0. +func (bs BlockSeq) TakeFirst(n int) BlockSeq { + if n < 0 { + panic(fmt.Sprintf("invalid n: %d", n)) + } + return bs.TakeFirst64(uint64(n)) +} + +// TakeFirst64 is equivalent to TakeFirst but takes a uint64. +func (bs BlockSeq) TakeFirst64(n uint64) BlockSeq { + if n == 0 { + return BlockSeq{} + } + if bs.limit > n { + bs.limit = n + } + return bs +} + +// String implements fmt.Stringer.String. +func (bs BlockSeq) String() string { + var buf bytes.Buffer + buf.WriteByte('[') + var sep string + for !bs.IsEmpty() { + buf.WriteString(sep) + sep = " " + buf.WriteString(bs.Head().String()) + bs = bs.Tail() + } + buf.WriteByte(']') + return buf.String() +} + +// CopySeq copies srcs.NumBytes() or dsts.NumBytes() bytes, whichever is less, +// from srcs to dsts and returns the number of bytes copied. +// +// If srcs and dsts overlap, the data stored in dsts is unspecified. +func CopySeq(dsts, srcs BlockSeq) (uint64, error) { + var done uint64 + for !dsts.IsEmpty() && !srcs.IsEmpty() { + dst := dsts.Head() + src := srcs.Head() + n, err := Copy(dst, src) + done += uint64(n) + if err != nil { + return done, err + } + dsts = dsts.DropFirst(n) + srcs = srcs.DropFirst(n) + } + return done, nil +} + +// ZeroSeq sets all bytes in dsts to 0 and returns the number of bytes zeroed. +func ZeroSeq(dsts BlockSeq) (uint64, error) { + var done uint64 + for !dsts.IsEmpty() { + n, err := Zero(dsts.Head()) + done += uint64(n) + if err != nil { + return done, err + } + dsts = dsts.DropFirst(n) + } + return done, nil +} |