summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/sentry/usermem/usermem.go36
-rw-r--r--pkg/sentry/usermem/usermem_test.go15
2 files changed, 37 insertions, 14 deletions
diff --git a/pkg/sentry/usermem/usermem.go b/pkg/sentry/usermem/usermem.go
index 31e4d6ada..9dde327a2 100644
--- a/pkg/sentry/usermem/usermem.go
+++ b/pkg/sentry/usermem/usermem.go
@@ -222,9 +222,11 @@ func CopyObjectIn(ctx context.Context, uio IO, addr Addr, dst interface{}, opts
return int(r.Addr - addr), nil
}
-// copyStringIncrement is the maximum number of bytes that are copied from
-// virtual memory at a time by CopyStringIn.
-const copyStringIncrement = 64
+// CopyStringIn tuning parameters, defined outside that function for tests.
+const (
+ copyStringIncrement = 64
+ copyStringMaxInitBufLen = 256
+)
// CopyStringIn copies a NUL-terminated string of unknown length from the
// memory mapped at addr in uio and returns it as a string (not including the
@@ -234,31 +236,38 @@ const copyStringIncrement = 64
//
// Preconditions: As for IO.CopyFromUser. maxlen >= 0.
func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpts) (string, error) {
- buf := make([]byte, maxlen)
+ initLen := maxlen
+ if initLen > copyStringMaxInitBufLen {
+ initLen = copyStringMaxInitBufLen
+ }
+ buf := make([]byte, initLen)
var done int
for done < maxlen {
- start, ok := addr.AddLength(uint64(done))
- if !ok {
- // Last page of kernel memory. The application can't use this
- // anyway.
- return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
- }
// Read up to copyStringIncrement bytes at a time.
readlen := copyStringIncrement
if readlen > maxlen-done {
readlen = maxlen - done
}
- end, ok := start.AddLength(uint64(readlen))
+ end, ok := addr.AddLength(uint64(readlen))
if !ok {
return stringFromImmutableBytes(buf[:done]), syserror.EFAULT
}
// Shorten the read to avoid crossing page boundaries, since faulting
// in a page unnecessarily is expensive. This also ensures that partial
// copies up to the end of application-mappable memory succeed.
- if start.RoundDown() != end.RoundDown() {
+ if addr.RoundDown() != end.RoundDown() {
end = end.RoundDown()
+ readlen = int(end - addr)
+ }
+ // Ensure that our buffer is large enough to accommodate the read.
+ if done+readlen > len(buf) {
+ newBufLen := len(buf) * 2
+ if newBufLen > maxlen {
+ newBufLen = maxlen
+ }
+ buf = append(buf, make([]byte, newBufLen-len(buf))...)
}
- n, err := uio.CopyIn(ctx, start, buf[done:done+int(end-start)], opts)
+ n, err := uio.CopyIn(ctx, addr, buf[done:done+readlen], opts)
// Look for the terminating zero byte, which may have occurred before
// hitting err.
for i, c := range buf[done : done+n] {
@@ -270,6 +279,7 @@ func CopyStringIn(ctx context.Context, uio IO, addr Addr, maxlen int, opts IOOpt
if err != nil {
return stringFromImmutableBytes(buf[:done]), err
}
+ addr = end
}
return stringFromImmutableBytes(buf), syserror.ENAMETOOLONG
}
diff --git a/pkg/sentry/usermem/usermem_test.go b/pkg/sentry/usermem/usermem_test.go
index 4a07118b7..575e5039d 100644
--- a/pkg/sentry/usermem/usermem_test.go
+++ b/pkg/sentry/usermem/usermem_test.go
@@ -192,6 +192,7 @@ func TestCopyObject(t *testing.T) {
}
func TestCopyStringInShort(t *testing.T) {
+ // Tests for string length <= copyStringIncrement.
want := strings.Repeat("A", copyStringIncrement-2)
mem := want + "\x00"
if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
@@ -200,13 +201,25 @@ func TestCopyStringInShort(t *testing.T) {
}
func TestCopyStringInLong(t *testing.T) {
- want := strings.Repeat("A", copyStringIncrement+1)
+ // Tests for copyStringIncrement < string length <= copyStringMaxInitBufLen
+ // (requiring multiple calls to IO.CopyIn()).
+ want := strings.Repeat("A", copyStringIncrement*3/4) + strings.Repeat("B", copyStringIncrement*3/4)
mem := want + "\x00"
if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringIncrement, IOOpts{}); got != want || err != nil {
t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
}
}
+func TestCopyStringInVeryLong(t *testing.T) {
+ // Tests for string length > copyStringMaxInitBufLen (requiring buffer
+ // reallocation).
+ want := strings.Repeat("A", copyStringMaxInitBufLen*3/4) + strings.Repeat("B", copyStringMaxInitBufLen*3/4)
+ mem := want + "\x00"
+ if got, err := CopyStringIn(newContext(), newBytesIOString(mem), 0, 2*copyStringMaxInitBufLen, IOOpts{}); got != want || err != nil {
+ t.Errorf("CopyStringIn: got (%q, %v), wanted (%q, nil)", got, err, want)
+ }
+}
+
func TestCopyStringInNoTerminatingZeroByte(t *testing.T) {
want := strings.Repeat("A", copyStringIncrement-1)
got, err := CopyStringIn(newContext(), newBytesIOString(want), 0, 2*copyStringIncrement, IOOpts{})