summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/usermem
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/usermem')
-rw-r--r--pkg/sentry/usermem/BUILD57
-rw-r--r--pkg/sentry/usermem/README.md31
-rwxr-xr-xpkg/sentry/usermem/addr_range.go62
-rw-r--r--pkg/sentry/usermem/addr_range_seq_test.go197
-rwxr-xr-xpkg/sentry/usermem/usermem_state_autogen.go49
-rw-r--r--pkg/sentry/usermem/usermem_test.go424
6 files changed, 111 insertions, 709 deletions
diff --git a/pkg/sentry/usermem/BUILD b/pkg/sentry/usermem/BUILD
deleted file mode 100644
index e38b31b08..000000000
--- a/pkg/sentry/usermem/BUILD
+++ /dev/null
@@ -1,57 +0,0 @@
-package(licenses = ["notice"])
-
-load("//tools/go_generics:defs.bzl", "go_template_instance")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
-
-go_template_instance(
- name = "addr_range",
- out = "addr_range.go",
- package = "usermem",
- prefix = "Addr",
- template = "//pkg/segment:generic_range",
- types = {
- "T": "Addr",
- },
-)
-
-go_library(
- name = "usermem",
- srcs = [
- "access_type.go",
- "addr.go",
- "addr_range.go",
- "addr_range_seq_unsafe.go",
- "bytes_io.go",
- "bytes_io_unsafe.go",
- "usermem.go",
- "usermem_arm64.go",
- "usermem_unsafe.go",
- "usermem_x86.go",
- ],
- importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/usermem",
- visibility = ["//pkg/sentry:internal"],
- deps = [
- "//pkg/atomicbitops",
- "//pkg/binary",
- "//pkg/log",
- "//pkg/sentry/context",
- "//pkg/sentry/safemem",
- "//pkg/syserror",
- "//pkg/tcpip/buffer",
- ],
-)
-
-go_test(
- name = "usermem_test",
- size = "small",
- srcs = [
- "addr_range_seq_test.go",
- "usermem_test.go",
- ],
- embed = [":usermem"],
- deps = [
- "//pkg/sentry/context",
- "//pkg/sentry/safemem",
- "//pkg/syserror",
- ],
-)
diff --git a/pkg/sentry/usermem/README.md b/pkg/sentry/usermem/README.md
deleted file mode 100644
index f6d2137eb..000000000
--- a/pkg/sentry/usermem/README.md
+++ /dev/null
@@ -1,31 +0,0 @@
-This package defines primitives for sentry access to application memory.
-
-Major types:
-
-- The `IO` interface represents a virtual address space and provides I/O
- methods on that address space. `IO` is the lowest-level primitive. The
- primary implementation of the `IO` interface is `mm.MemoryManager`.
-
-- `IOSequence` represents a collection of individually-contiguous address
- ranges in a `IO` that is operated on sequentially, analogous to Linux's
- `struct iov_iter`.
-
-Major usage patterns:
-
-- Access to a task's virtual memory, subject to the application's memory
- protections and while running on that task's goroutine, from a context that
- is at or above the level of the `kernel` package (e.g. most syscall
- implementations in `syscalls/linux`); use the `kernel.Task.Copy*` wrappers
- defined in `kernel/task_usermem.go`.
-
-- Access to a task's virtual memory, from a context that is at or above the
- level of the `kernel` package, but where any of the above constraints does
- not hold (e.g. `PTRACE_POKEDATA`, which ignores application memory
- protections); obtain the task's `mm.MemoryManager` by calling
- `kernel.Task.MemoryManager`, and call its `IO` methods directly.
-
-- Access to a task's virtual memory, from a context that is below the level of
- the `kernel` package (e.g. filesystem I/O); clients must pass I/O arguments
- from higher layers, usually in the form of an `IOSequence`. The
- `kernel.Task.SingleIOSequence` and `kernel.Task.IovecsIOSequence` functions
- in `kernel/task_usermem.go` are convenience functions for doing so.
diff --git a/pkg/sentry/usermem/addr_range.go b/pkg/sentry/usermem/addr_range.go
new file mode 100755
index 000000000..152ed1434
--- /dev/null
+++ b/pkg/sentry/usermem/addr_range.go
@@ -0,0 +1,62 @@
+package usermem
+
+// A Range represents a contiguous range of T.
+//
+// +stateify savable
+type AddrRange struct {
+ // Start is the inclusive start of the range.
+ Start Addr
+
+ // End is the exclusive end of the range.
+ End Addr
+}
+
+// WellFormed returns true if r.Start <= r.End. All other methods on a Range
+// require that the Range is well-formed.
+func (r AddrRange) WellFormed() bool {
+ return r.Start <= r.End
+}
+
+// Length returns the length of the range.
+func (r AddrRange) Length() Addr {
+ return r.End - r.Start
+}
+
+// Contains returns true if r contains x.
+func (r AddrRange) Contains(x Addr) bool {
+ return r.Start <= x && x < r.End
+}
+
+// Overlaps returns true if r and r2 overlap.
+func (r AddrRange) Overlaps(r2 AddrRange) bool {
+ return r.Start < r2.End && r2.Start < r.End
+}
+
+// IsSupersetOf returns true if r is a superset of r2; that is, the range r2 is
+// contained within r.
+func (r AddrRange) IsSupersetOf(r2 AddrRange) bool {
+ return r.Start <= r2.Start && r.End >= r2.End
+}
+
+// Intersect returns a range consisting of the intersection between r and r2.
+// If r and r2 do not overlap, Intersect returns a range with unspecified
+// bounds, but for which Length() == 0.
+func (r AddrRange) Intersect(r2 AddrRange) AddrRange {
+ if r.Start < r2.Start {
+ r.Start = r2.Start
+ }
+ if r.End > r2.End {
+ r.End = r2.End
+ }
+ if r.End < r.Start {
+ r.End = r.Start
+ }
+ return r
+}
+
+// CanSplitAt returns true if it is legal to split a segment spanning the range
+// r at x; that is, splitting at x would produce two ranges, both of which have
+// non-zero length.
+func (r AddrRange) CanSplitAt(x Addr) bool {
+ return r.Contains(x) && r.Start < x
+}
diff --git a/pkg/sentry/usermem/addr_range_seq_test.go b/pkg/sentry/usermem/addr_range_seq_test.go
deleted file mode 100644
index 82f735026..000000000
--- a/pkg/sentry/usermem/addr_range_seq_test.go
+++ /dev/null
@@ -1,197 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package usermem
-
-import (
- "testing"
-)
-
-var addrRangeSeqTests = []struct {
- desc string
- ranges []AddrRange
-}{
- {
- desc: "Empty sequence",
- },
- {
- desc: "Single empty AddrRange",
- ranges: []AddrRange{
- {0x10, 0x10},
- },
- },
- {
- desc: "Single non-empty AddrRange of length 1",
- ranges: []AddrRange{
- {0x10, 0x11},
- },
- },
- {
- desc: "Single non-empty AddrRange of length 2",
- ranges: []AddrRange{
- {0x10, 0x12},
- },
- },
- {
- desc: "Multiple non-empty AddrRanges",
- ranges: []AddrRange{
- {0x10, 0x11},
- {0x20, 0x22},
- },
- },
- {
- desc: "Multiple AddrRanges including empty AddrRanges",
- ranges: []AddrRange{
- {0x10, 0x10},
- {0x20, 0x20},
- {0x30, 0x33},
- {0x40, 0x44},
- {0x50, 0x50},
- {0x60, 0x60},
- {0x70, 0x77},
- {0x80, 0x88},
- {0x90, 0x90},
- {0xa0, 0xa0},
- },
- },
-}
-
-func testAddrRangeSeqEqualityWithTailIteration(t *testing.T, ars AddrRangeSeq, wantRanges []AddrRange) {
- var wantLen int64
- for _, ar := range wantRanges {
- wantLen += int64(ar.Length())
- }
-
- var i int
- for !ars.IsEmpty() {
- if gotLen := ars.NumBytes(); gotLen != wantLen {
- t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen)
- }
- if gotN, wantN := ars.NumRanges(), len(wantRanges)-i; gotN != wantN {
- t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted %d", i, ars, gotN, wantN)
- }
- got := ars.Head()
- if i >= len(wantRanges) {
- t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got)
- } else if want := wantRanges[i]; got != want {
- t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want)
- }
- ars = ars.Tail()
- wantLen -= int64(got.Length())
- i++
- }
- if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 {
- t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen)
- }
- if gotN := ars.NumRanges(); gotN != 0 {
- t.Errorf("Iteration %d: %v.NumRanges(): got %d, wanted 0", i, ars, gotN)
- }
-}
-
-func TestAddrRangeSeqTailIteration(t *testing.T) {
- for _, test := range addrRangeSeqTests {
- t.Run(test.desc, func(t *testing.T) {
- testAddrRangeSeqEqualityWithTailIteration(t, AddrRangeSeqFromSlice(test.ranges), test.ranges)
- })
- }
-}
-
-func TestAddrRangeSeqDropFirstEmpty(t *testing.T) {
- var ars AddrRangeSeq
- if got, want := ars.DropFirst(1), ars; got != want {
- t.Errorf("%v.DropFirst(1): got %v, wanted %v", ars, got, want)
- }
-}
-
-func TestAddrRangeSeqDropSingleByteIteration(t *testing.T) {
- // Tests AddrRangeSeq iteration using Head/DropFirst, simulating
- // I/O-per-AddrRange.
- for _, test := range addrRangeSeqTests {
- t.Run(test.desc, func(t *testing.T) {
- // Figure out what AddrRanges we expect to see.
- var wantLen int64
- var wantRanges []AddrRange
- for _, ar := range test.ranges {
- wantLen += int64(ar.Length())
- wantRanges = append(wantRanges, ar)
- if ar.Length() == 0 {
- // We "do" 0 bytes of I/O and then call DropFirst(0),
- // advancing to the next AddrRange.
- continue
- }
- // Otherwise we "do" 1 byte of I/O and then call DropFirst(1),
- // advancing the AddrRange by 1 byte, or to the next AddrRange
- // if this one is exhausted.
- for ar.Start++; ar.Length() != 0; ar.Start++ {
- wantRanges = append(wantRanges, ar)
- }
- }
- t.Logf("Expected AddrRanges: %s (%d bytes)", wantRanges, wantLen)
-
- ars := AddrRangeSeqFromSlice(test.ranges)
- var i int
- for !ars.IsEmpty() {
- if gotLen := ars.NumBytes(); gotLen != wantLen {
- t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d", i, ars, gotLen, wantLen)
- }
- got := ars.Head()
- if i >= len(wantRanges) {
- t.Errorf("Iteration %d: %v.Head(): got %s, wanted <end of sequence>", i, ars, got)
- } else if want := wantRanges[i]; got != want {
- t.Errorf("Iteration %d: %v.Head(): got %s, wanted %s", i, ars, got, want)
- }
- if got.Length() == 0 {
- ars = ars.DropFirst(0)
- } else {
- ars = ars.DropFirst(1)
- wantLen--
- }
- i++
- }
- if gotLen := ars.NumBytes(); gotLen != 0 || wantLen != 0 {
- t.Errorf("Iteration %d: %v.NumBytes(): got %d, wanted %d (which should be 0)", i, ars, gotLen, wantLen)
- }
- })
- }
-}
-
-func TestAddrRangeSeqTakeFirstEmpty(t *testing.T) {
- var ars AddrRangeSeq
- if got, want := ars.TakeFirst(1), ars; got != want {
- t.Errorf("%v.TakeFirst(1): got %v, wanted %v", ars, got, want)
- }
-}
-
-func TestAddrRangeSeqTakeFirst(t *testing.T) {
- ranges := []AddrRange{
- {0x10, 0x11},
- {0x20, 0x22},
- {0x30, 0x30},
- {0x40, 0x44},
- {0x50, 0x55},
- {0x60, 0x60},
- {0x70, 0x77},
- }
- ars := AddrRangeSeqFromSlice(ranges).TakeFirst(5)
- want := []AddrRange{
- {0x10, 0x11}, // +1 byte (total 1 byte), not truncated
- {0x20, 0x22}, // +2 bytes (total 3 bytes), not truncated
- {0x30, 0x30}, // +0 bytes (total 3 bytes), no change
- {0x40, 0x42}, // +2 bytes (total 5 bytes), partially truncated
- {0x50, 0x50}, // +0 bytes (total 5 bytes), fully truncated
- {0x60, 0x60}, // +0 bytes (total 5 bytes), "fully truncated" (no change)
- {0x70, 0x70}, // +0 bytes (total 5 bytes), fully truncated
- }
- testAddrRangeSeqEqualityWithTailIteration(t, ars, want)
-}
diff --git a/pkg/sentry/usermem/usermem_state_autogen.go b/pkg/sentry/usermem/usermem_state_autogen.go
new file mode 100755
index 000000000..bc728eab3
--- /dev/null
+++ b/pkg/sentry/usermem/usermem_state_autogen.go
@@ -0,0 +1,49 @@
+// automatically generated by stateify.
+
+package usermem
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/state"
+)
+
+func (x *AccessType) beforeSave() {}
+func (x *AccessType) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Read", &x.Read)
+ m.Save("Write", &x.Write)
+ m.Save("Execute", &x.Execute)
+}
+
+func (x *AccessType) afterLoad() {}
+func (x *AccessType) load(m state.Map) {
+ m.Load("Read", &x.Read)
+ m.Load("Write", &x.Write)
+ m.Load("Execute", &x.Execute)
+}
+
+func (x *Addr) save(m state.Map) {
+ m.SaveValue("", (uintptr)(*x))
+}
+
+func (x *Addr) load(m state.Map) {
+ m.LoadValue("", new(uintptr), func(y interface{}) { *x = (Addr)(y.(uintptr)) })
+}
+
+func (x *AddrRange) beforeSave() {}
+func (x *AddrRange) save(m state.Map) {
+ x.beforeSave()
+ m.Save("Start", &x.Start)
+ m.Save("End", &x.End)
+}
+
+func (x *AddrRange) afterLoad() {}
+func (x *AddrRange) load(m state.Map) {
+ m.Load("Start", &x.Start)
+ m.Load("End", &x.End)
+}
+
+func init() {
+ state.Register("usermem.AccessType", (*AccessType)(nil), state.Fns{Save: (*AccessType).save, Load: (*AccessType).load})
+ state.Register("usermem.Addr", (*Addr)(nil), state.Fns{Save: (*Addr).save, Load: (*Addr).load})
+ state.Register("usermem.AddrRange", (*AddrRange)(nil), state.Fns{Save: (*AddrRange).save, Load: (*AddrRange).load})
+}
diff --git a/pkg/sentry/usermem/usermem_test.go b/pkg/sentry/usermem/usermem_test.go
deleted file mode 100644
index 575e5039d..000000000
--- a/pkg/sentry/usermem/usermem_test.go
+++ /dev/null
@@ -1,424 +0,0 @@
-// Copyright 2018 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-package usermem
-
-import (
- "bytes"
- "encoding/binary"
- "fmt"
- "reflect"
- "strings"
- "testing"
-
- "gvisor.googlesource.com/gvisor/pkg/sentry/context"
- "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
- "gvisor.googlesource.com/gvisor/pkg/syserror"
-)
-
-// newContext returns a context.Context that we can use in these tests (we
-// can't use contexttest because it depends on usermem).
-func newContext() context.Context {
- return context.Background()
-}
-
-func newBytesIOString(s string) *BytesIO {
- return &BytesIO{[]byte(s)}
-}
-
-func TestBytesIOCopyOutSuccess(t *testing.T) {
- b := newBytesIOString("ABCDE")
- n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
- if wantN := 3; n != wantN || err != nil {
- t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if got, want := b.Bytes, []byte("AfooE"); !bytes.Equal(got, want) {
- t.Errorf("Bytes: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOCopyOutFailure(t *testing.T) {
- b := newBytesIOString("ABC")
- n, err := b.CopyOut(newContext(), 1, []byte("foo"), IOOpts{})
- if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr {
- t.Errorf("CopyOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
- }
- if got, want := b.Bytes, []byte("Afo"); !bytes.Equal(got, want) {
- t.Errorf("Bytes: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOCopyInSuccess(t *testing.T) {
- b := newBytesIOString("AfooE")
- var dst [3]byte
- n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
- if wantN := 3; n != wantN || err != nil {
- t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
- t.Errorf("dst: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOCopyInFailure(t *testing.T) {
- b := newBytesIOString("Afo")
- var dst [3]byte
- n, err := b.CopyIn(newContext(), 1, dst[:], IOOpts{})
- if wantN, wantErr := 2, syserror.EFAULT; n != wantN || err != wantErr {
- t.Errorf("CopyIn: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
- }
- if got, want := dst[:], []byte("fo\x00"); !bytes.Equal(got, want) {
- t.Errorf("dst: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOZeroOutSuccess(t *testing.T) {
- b := newBytesIOString("ABCD")
- n, err := b.ZeroOut(newContext(), 1, 2, IOOpts{})
- if wantN := int64(2); n != wantN || err != nil {
- t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if got, want := b.Bytes, []byte("A\x00\x00D"); !bytes.Equal(got, want) {
- t.Errorf("Bytes: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOZeroOutFailure(t *testing.T) {
- b := newBytesIOString("ABC")
- n, err := b.ZeroOut(newContext(), 1, 3, IOOpts{})
- if wantN, wantErr := int64(2), syserror.EFAULT; n != wantN || err != wantErr {
- t.Errorf("ZeroOut: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
- }
- if got, want := b.Bytes, []byte("A\x00\x00"); !bytes.Equal(got, want) {
- t.Errorf("Bytes: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOCopyOutFromSuccess(t *testing.T) {
- b := newBytesIOString("ABCDEFGH")
- n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{
- {Start: 4, End: 7},
- {Start: 1, End: 4},
- }), safemem.FromIOReader{bytes.NewBufferString("barfoo")}, IOOpts{})
- if wantN := int64(6); n != wantN || err != nil {
- t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if got, want := b.Bytes, []byte("AfoobarH"); !bytes.Equal(got, want) {
- t.Errorf("Bytes: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOCopyOutFromFailure(t *testing.T) {
- b := newBytesIOString("ABCDE")
- n, err := b.CopyOutFrom(newContext(), AddrRangeSeqFromSlice([]AddrRange{
- {Start: 1, End: 4},
- {Start: 4, End: 7},
- }), safemem.FromIOReader{bytes.NewBufferString("foobar")}, IOOpts{})
- if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr {
- t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
- }
- if got, want := b.Bytes, []byte("Afoob"); !bytes.Equal(got, want) {
- t.Errorf("Bytes: got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOCopyInToSuccess(t *testing.T) {
- b := newBytesIOString("AfoobarH")
- var dst bytes.Buffer
- n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{
- {Start: 4, End: 7},
- {Start: 1, End: 4},
- }), safemem.FromIOWriter{&dst}, IOOpts{})
- if wantN := int64(6); n != wantN || err != nil {
- t.Errorf("CopyInTo: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if got, want := dst.Bytes(), []byte("barfoo"); !bytes.Equal(got, want) {
- t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
- }
-}
-
-func TestBytesIOCopyInToFailure(t *testing.T) {
- b := newBytesIOString("Afoob")
- var dst bytes.Buffer
- n, err := b.CopyInTo(newContext(), AddrRangeSeqFromSlice([]AddrRange{
- {Start: 1, End: 4},
- {Start: 4, End: 7},
- }), safemem.FromIOWriter{&dst}, IOOpts{})
- if wantN, wantErr := int64(4), syserror.EFAULT; n != wantN || err != wantErr {
- t.Errorf("CopyOutFrom: got (%v, %v), wanted (%v, %v)", n, err, wantN, wantErr)
- }
- if got, want := dst.Bytes(), []byte("foob"); !bytes.Equal(got, want) {
- t.Errorf("dst.Bytes(): got %q, wanted %q", got, want)
- }
-}
-
-type testStruct struct {
- Int8 int8
- Uint8 uint8
- Int16 int16
- Uint16 uint16
- Int32 int32
- Uint32 uint32
- Int64 int64
- Uint64 uint64
-}
-
-func TestCopyObject(t *testing.T) {
- wantObj := testStruct{1, 2, 3, 4, 5, 6, 7, 8}
- wantN := binary.Size(wantObj)
- b := &BytesIO{make([]byte, wantN)}
- ctx := newContext()
- if n, err := CopyObjectOut(ctx, b, 0, &wantObj, IOOpts{}); n != wantN || err != nil {
- t.Fatalf("CopyObjectOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- var gotObj testStruct
- if n, err := CopyObjectIn(ctx, b, 0, &gotObj, IOOpts{}); n != wantN || err != nil {
- t.Errorf("CopyObjectIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if gotObj != wantObj {
- t.Errorf("CopyObject round trip: got %+v, wanted %+v", gotObj, wantObj)
- }
-}
-
-func TestCopyStringInShort(t *testing.T) {
- // Tests for string length <= copyStringIncrement.
- want := strings.Repeat("A", copyStringIncrement-2)
- mem := want + "\x00"
- if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
- t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
- }
-}
-
-func TestCopyStringInLong(t *testing.T) {
- // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen
- // (requiring multiple calls to IO.CopyIn()).
- want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4)
- mem := want + "\x00"
- if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
- t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
- }
-}
-
-func TestCopyStringInVeryLong(t *testing.T) {
- // Tests for string length > copyStringMaxInitBufLen (requiring buffer
- // reallocation).
- want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4)
- mem := want + "\x00"
- if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil {
- t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
- }
-}
-
-func TestCopyStringInNoTerminatingZeroByte(t *testing.T) {
- want := strings.Repeat("A", copyStringIncrement-1)
- got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{})
- if wantErr := syserror.EFAULT; got != want || err != wantErr {
- t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
- }
-}
-
-func TestCopyStringInTruncatedByMaxlen(t *testing.T) {
- got, err := CopyStringIn(newContext(), newBytesIOString(strings.Repeat("A", 10)), 0, 5, IOOpts{})
- if want, wantErr := strings.Repeat("A", 5), syserror.ENAMETOOLONG; got != want || err != wantErr {
- t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, %v)", got, err, want, wantErr)
- }
-}
-
-func TestCopyInt32StringsInVec(t *testing.T) {
- for _, test := range []struct {
- str string
- n int
- initial []int32
- final []int32
- }{
- {
- str: "100 200",
- n: len("100 200"),
- initial: []int32{1, 2},
- final: []int32{100, 200},
- },
- {
- // Fewer values ok
- str: "100",
- n: len("100"),
- initial: []int32{1, 2},
- final: []int32{100, 2},
- },
- {
- // Extra values ok
- str: "100 200 300",
- n: len("100 200 "),
- initial: []int32{1, 2},
- final: []int32{100, 200},
- },
- {
- // Leading and trailing whitespace ok
- str: " 100\t200\n",
- n: len(" 100\t200\n"),
- initial: []int32{1, 2},
- final: []int32{100, 200},
- },
- } {
- t.Run(fmt.Sprintf("%q", test.str), func(t *testing.T) {
- src := BytesIOSequence([]byte(test.str))
- dsts := append([]int32(nil), test.initial...)
- if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); n != int64(test.n) || err != nil {
- t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (%d, nil)", n, err, test.n)
- }
- if !reflect.DeepEqual(dsts, test.final) {
- t.Errorf("dsts: got %v, wanted %v", dsts, test.final)
- }
- })
- }
-}
-
-func TestCopyInt32StringsInVecRequiresOneValidValue(t *testing.T) {
- for _, s := range []string{"", "\n", "a123"} {
- t.Run(fmt.Sprintf("%q", s), func(t *testing.T) {
- src := BytesIOSequence([]byte(s))
- initial := []int32{1, 2}
- dsts := append([]int32(nil), initial...)
- if n, err := CopyInt32StringsInVec(newContext(), src.IO, src.Addrs, dsts, src.Opts); err != syserror.EINVAL {
- t.Errorf("CopyInt32StringsInVec: got (%d, %v), wanted (_, %v)", n, err, syserror.EINVAL)
- }
- if !reflect.DeepEqual(dsts, initial) {
- t.Errorf("dsts: got %v, wanted %v", dsts, initial)
- }
- })
- }
-}
-
-func TestIOSequenceCopyOut(t *testing.T) {
- buf := []byte("ABCD")
- s := BytesIOSequence(buf)
-
- // CopyOut limited by len(src).
- n, err := s.CopyOut(newContext(), []byte("fo"))
- if wantN := 2; n != wantN || err != nil {
- t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if want := []byte("foCD"); !bytes.Equal(buf, want) {
- t.Errorf("buf: got %q, wanted %q", buf, want)
- }
- s = s.DropFirst(2)
- if got, want := s.NumBytes(), int64(2); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-
- // CopyOut limited by s.NumBytes().
- n, err = s.CopyOut(newContext(), []byte("obar"))
- if wantN := 2; n != wantN || err != nil {
- t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if want := []byte("foob"); !bytes.Equal(buf, want) {
- t.Errorf("buf: got %q, wanted %q", buf, want)
- }
- s = s.DropFirst(2)
- if got, want := s.NumBytes(), int64(0); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-}
-
-func TestIOSequenceCopyIn(t *testing.T) {
- s := BytesIOSequence([]byte("foob"))
- dst := []byte("ABCDEF")
-
- // CopyIn limited by len(dst).
- n, err := s.CopyIn(newContext(), dst[:2])
- if wantN := 2; n != wantN || err != nil {
- t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if want := []byte("foCDEF"); !bytes.Equal(dst, want) {
- t.Errorf("dst: got %q, wanted %q", dst, want)
- }
- s = s.DropFirst(2)
- if got, want := s.NumBytes(), int64(2); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-
- // CopyIn limited by s.Remaining().
- n, err = s.CopyIn(newContext(), dst[2:])
- if wantN := 2; n != wantN || err != nil {
- t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if want := []byte("foobEF"); !bytes.Equal(dst, want) {
- t.Errorf("dst: got %q, wanted %q", dst, want)
- }
- s = s.DropFirst(2)
- if got, want := s.NumBytes(), int64(0); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-}
-
-func TestIOSequenceZeroOut(t *testing.T) {
- buf := []byte("ABCD")
- s := BytesIOSequence(buf)
-
- // ZeroOut limited by toZero.
- n, err := s.ZeroOut(newContext(), 2)
- if wantN := int64(2); n != wantN || err != nil {
- t.Errorf("ZeroOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if want := []byte("\x00\x00CD"); !bytes.Equal(buf, want) {
- t.Errorf("buf: got %q, wanted %q", buf, want)
- }
- s = s.DropFirst(2)
- if got, want := s.NumBytes(), int64(2); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-
- // ZeroOut limited by s.NumBytes().
- n, err = s.ZeroOut(newContext(), 4)
- if wantN := int64(2); n != wantN || err != nil {
- t.Errorf("CopyOut: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if want := []byte("\x00\x00\x00\x00"); !bytes.Equal(buf, want) {
- t.Errorf("buf: got %q, wanted %q", buf, want)
- }
- s = s.DropFirst(2)
- if got, want := s.NumBytes(), int64(0); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-}
-
-func TestIOSequenceTakeFirst(t *testing.T) {
- s := BytesIOSequence([]byte("foobar"))
- if got, want := s.NumBytes(), int64(6); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-
- s = s.TakeFirst(3)
- if got, want := s.NumBytes(), int64(3); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-
- // TakeFirst(n) where n > s.NumBytes() is a no-op.
- s = s.TakeFirst(9)
- if got, want := s.NumBytes(), int64(3); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-
- var dst [3]byte
- n, err := s.CopyIn(newContext(), dst[:])
- if wantN := 3; n != wantN || err != nil {
- t.Errorf("CopyIn: got (%v, %v), wanted (%v, nil)", n, err, wantN)
- }
- if got, want := dst[:], []byte("foo"); !bytes.Equal(got, want) {
- t.Errorf("dst: got %q, wanted %q", got, want)
- }
- s = s.DropFirst(3)
- if got, want := s.NumBytes(), int64(0); got != want {
- t.Errorf("NumBytes: got %v, wanted %v", got, want)
- }
-}