summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry
diff options
context:
space:
mode:
authorAyush Ranjan <ayushranjan@google.com>2021-11-05 10:40:53 -0700
committergVisor bot <gvisor-bot@google.com>2021-11-05 10:43:49 -0700
commitce4f4283badb6b07baf9f8e6d99e7a5fd15c92db (patch)
tree848dc50da62da59dc4a5781f9eb7461c58b71512 /pkg/sentry
parent822a647018adbd994114cb0dc8932f2853b805aa (diff)
Make {Un}Marshal{Bytes/Unsafe} return remaining buffer.
Change marshal.Marshallable method signatures to return the remaining buffer. This makes it easier to implement these method manually. Without this, we would have to manually do buffer shifting which is error prone. tools/go_marshal/test:benchmark test does not show change in performance. Additionally fixed some marshalling bugs in fsimpl/fuse. Updated multiple callpoints to get rid of redundant slice indexing work and simplified code using this new signature. Updates #6450 PiperOrigin-RevId: 407857019
Diffstat (limited to 'pkg/sentry')
-rw-r--r--pkg/sentry/fsimpl/fuse/BUILD3
-rw-r--r--pkg/sentry/fsimpl/fuse/connection_test.go11
-rw-r--r--pkg/sentry/fsimpl/fuse/dev_test.go19
-rw-r--r--pkg/sentry/fsimpl/fuse/fusefs.go16
-rw-r--r--pkg/sentry/fsimpl/fuse/request_response.go4
-rw-r--r--pkg/sentry/fsimpl/fuse/utils_test.go59
-rw-r--r--pkg/sentry/loader/elf.go3
-rw-r--r--pkg/sentry/socket/control/control.go63
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go5
-rw-r--r--pkg/sentry/socket/hostinet/socket.go14
-rw-r--r--pkg/sentry/socket/hostinet/stack.go6
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go5
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go5
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go5
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go7
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/targets.go15
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/sentry/socket/netlink/message_test.go45
-rw-r--r--pkg/sentry/socket/netlink/socket.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go14
-rw-r--r--pkg/sentry/socket/socket.go14
-rw-r--r--pkg/sentry/strace/socket.go131
24 files changed, 172 insertions, 280 deletions
diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD
index 05c4fbeb2..18497a880 100644
--- a/pkg/sentry/fsimpl/fuse/BUILD
+++ b/pkg/sentry/fsimpl/fuse/BUILD
@@ -77,8 +77,7 @@ go_test(
deps = [
"//pkg/abi/linux",
"//pkg/errors/linuxerr",
- "//pkg/hostarch",
- "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/fsimpl/testutil",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go
index 1fddd858e..d98d2832b 100644
--- a/pkg/sentry/fsimpl/fuse/connection_test.go
+++ b/pkg/sentry/fsimpl/fuse/connection_test.go
@@ -20,6 +20,7 @@ import (
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
)
@@ -69,14 +70,10 @@ func TestConnectionAbort(t *testing.T) {
t.Fatalf("newTestConnection: %v", err)
}
- testObj := &testPayload{
- data: rand.Uint32(),
- }
-
var futNormal []*futureResponse
-
+ testObj := primitive.Uint32(rand.Uint32())
for i := 0; i < int(numRequests); i++ {
- req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj)
+ req := conn.NewRequest(creds, uint32(i), uint64(i), 0, &testObj)
fut, err := conn.callFutureLocked(task, req)
if err != nil {
t.Fatalf("callFutureLocked failed: %v", err)
@@ -102,7 +99,7 @@ func TestConnectionAbort(t *testing.T) {
}
// After abort, Call() should return directly with ENOTCONN.
- req := conn.NewRequest(creds, 0, 0, 0, testObj)
+ req := conn.NewRequest(creds, 0, 0, 0, &testObj)
_, err = conn.Call(task, req)
if !linuxerr.Equals(linuxerr.ENOTCONN, err) {
t.Fatalf("Incorrect error code received for Call() after connection aborted")
diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go
index 8951b5ba8..13b32fc7c 100644
--- a/pkg/sentry/fsimpl/fuse/dev_test.go
+++ b/pkg/sentry/fsimpl/fuse/dev_test.go
@@ -21,6 +21,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
@@ -215,11 +216,9 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con
if err != nil {
t.Fatal(err)
}
- testObj := &testPayload{
- data: rand.Uint32(),
- }
- req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj)
+ testObj := primitive.Uint32(rand.Uint32())
+ req := conn.NewRequest(creds, pid, inode, echoTestOpcode, &testObj)
// Queue up a request.
// Analogous to Call except it doesn't block on the task.
@@ -232,7 +231,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con
t.Fatalf("Server responded with an error: %v", err)
}
- var respTestPayload testPayload
+ var respTestPayload primitive.Uint32
if err := resp.UnmarshalPayload(&respTestPayload); err != nil {
t.Fatalf("Unmarshalling payload error: %v", err)
}
@@ -242,8 +241,8 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con
req.hdr.Unique, resp.hdr.Unique)
}
- if respTestPayload.data != testObj.data {
- t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data)
+ if respTestPayload != testObj {
+ t.Fatalf("read incorrect data. Data expected: %d, but got %d", testObj, respTestPayload)
}
}
@@ -256,8 +255,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F
// Create the tasks that the server will be using.
tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits())
- var readPayload testPayload
+ var readPayload primitive.Uint32
serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root)
if err != nil {
t.Fatal(err)
@@ -291,8 +290,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F
}
var readFUSEHeaderIn linux.FUSEHeaderIn
- readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen])
- readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen])
+ inBuf = readFUSEHeaderIn.UnmarshalUnsafe(inBuf)
+ readPayload.UnmarshalUnsafe(inBuf)
if readFUSEHeaderIn.Opcode != echoTestOpcode {
t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload)
diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go
index af16098d2..00b520c31 100644
--- a/pkg/sentry/fsimpl/fuse/fusefs.go
+++ b/pkg/sentry/fsimpl/fuse/fusefs.go
@@ -489,7 +489,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr
// Lookup implements kernfs.Inode.Lookup.
func (i *inode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) {
- in := linux.FUSELookupIn{Name: name}
+ in := linux.FUSELookupIn{Name: linux.CString(name)}
return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in)
}
@@ -520,7 +520,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions)
Mode: uint32(opts.Mode) | linux.S_IFREG,
Umask: uint32(kernelTask.FSContext().Umask()),
},
- Name: name,
+ Name: linux.CString(name),
}
return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in)
}
@@ -533,7 +533,7 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions)
Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor),
Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()),
},
- Name: name,
+ Name: linux.CString(name),
}
return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in)
}
@@ -541,8 +541,8 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions)
// NewSymlink implements kernfs.Inode.NewSymlink.
func (i *inode) NewSymlink(ctx context.Context, name, target string) (kernfs.Inode, error) {
in := linux.FUSESymLinkIn{
- Name: name,
- Target: target,
+ Name: linux.CString(name),
+ Target: linux.CString(target),
}
return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in)
}
@@ -554,7 +554,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err
log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID)
return linuxerr.EINVAL
}
- in := linux.FUSEUnlinkIn{Name: name}
+ in := linux.FUSEUnlinkIn{Name: linux.CString(name)}
req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in)
res, err := i.fs.conn.Call(kernelTask, req)
if err != nil {
@@ -571,7 +571,7 @@ func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions)
Mode: uint32(opts.Mode),
Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()),
},
- Name: name,
+ Name: linux.CString(name),
}
return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in)
}
@@ -581,7 +581,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro
fusefs := i.fs
task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx)
- in := linux.FUSERmDirIn{Name: name}
+ in := linux.FUSERmDirIn{Name: linux.CString(name)}
req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in)
res, err := i.fs.conn.Call(task, req)
if err != nil {
diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go
index 8a72489fa..ec76ec2a4 100644
--- a/pkg/sentry/fsimpl/fuse/request_response.go
+++ b/pkg/sentry/fsimpl/fuse/request_response.go
@@ -41,7 +41,7 @@ type fuseInitRes struct {
}
// UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes.
-func (r *fuseInitRes) UnmarshalBytes(src []byte) {
+func (r *fuseInitRes) UnmarshalBytes(src []byte) []byte {
out := &r.initOut
// Introduced before FUSE kernel version 7.13.
@@ -70,7 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) {
out.MaxPages = uint16(hostarch.ByteOrder.Uint16(src[:2]))
src = src[2:]
}
- _ = src // Remove unused warning.
+ return src
}
// SizeBytes is the size of the payload of the FUSE_INIT response.
diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go
index b0bab0066..8d4a2fad3 100644
--- a/pkg/sentry/fsimpl/fuse/utils_test.go
+++ b/pkg/sentry/fsimpl/fuse/utils_test.go
@@ -15,17 +15,13 @@
package fuse
import (
- "io"
"testing"
"gvisor.dev/gvisor/pkg/abi/linux"
- "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/vfs"
-
- "gvisor.dev/gvisor/pkg/hostarch"
)
func setup(t *testing.T) *testutil.System {
@@ -70,58 +66,3 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque
}
return fs.conn, &fuseDev.vfsfd, nil
}
-
-type testPayload struct {
- marshal.StubMarshallable
- data uint32
-}
-
-// SizeBytes implements marshal.Marshallable.SizeBytes.
-func (t *testPayload) SizeBytes() int {
- return 4
-}
-
-// MarshalBytes implements marshal.Marshallable.MarshalBytes.
-func (t *testPayload) MarshalBytes(dst []byte) {
- hostarch.ByteOrder.PutUint32(dst[:4], t.data)
-}
-
-// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.
-func (t *testPayload) UnmarshalBytes(src []byte) {
- *t = testPayload{data: hostarch.ByteOrder.Uint32(src[:4])}
-}
-
-// Packed implements marshal.Marshallable.Packed.
-func (t *testPayload) Packed() bool {
- return true
-}
-
-// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.
-func (t *testPayload) MarshalUnsafe(dst []byte) {
- t.MarshalBytes(dst)
-}
-
-// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.
-func (t *testPayload) UnmarshalUnsafe(src []byte) {
- t.UnmarshalBytes(src)
-}
-
-// CopyOutN implements marshal.Marshallable.CopyOutN.
-func (t *testPayload) CopyOutN(task marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) {
- panic("not implemented")
-}
-
-// CopyOut implements marshal.Marshallable.CopyOut.
-func (t *testPayload) CopyOut(task marshal.CopyContext, addr hostarch.Addr) (int, error) {
- panic("not implemented")
-}
-
-// CopyIn implements marshal.Marshallable.CopyIn.
-func (t *testPayload) CopyIn(task marshal.CopyContext, addr hostarch.Addr) (int, error) {
- panic("not implemented")
-}
-
-// WriteTo implements io.WriterTo.WriteTo.
-func (t *testPayload) WriteTo(w io.Writer) (int64, error) {
- panic("not implemented")
-}
diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go
index fb213d109..09b148164 100644
--- a/pkg/sentry/loader/elf.go
+++ b/pkg/sentry/loader/elf.go
@@ -212,8 +212,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) {
phdrs := make([]elf.ProgHeader, hdr.Phnum)
for i := range phdrs {
var prog64 linux.ElfProg64
- prog64.UnmarshalUnsafe(phdrBuf[:prog64Size])
- phdrBuf = phdrBuf[prog64Size:]
+ phdrBuf = prog64.UnmarshalUnsafe(phdrBuf)
phdrs[i] = elf.ProgHeader{
Type: elf.ProgType(prog64.Type),
Flags: elf.ProgFlag(prog64.Flags),
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 6077b2150..4b036b323 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -17,6 +17,8 @@
package control
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/context"
@@ -29,7 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "time"
)
// SCMCredentials represents a SCM_CREDENTIALS socket control message.
@@ -63,10 +64,10 @@ type RightsFiles []*fs.File
// NewSCMRights creates a new SCM_RIGHTS socket control message representation
// using local sentry FDs.
-func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
+func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) {
files := make(RightsFiles, 0, len(fds))
for _, fd := range fds {
- file := t.GetFile(fd)
+ file := t.GetFile(int32(fd))
if file == nil {
files.Release(t)
return nil, linuxerr.EBADF
@@ -486,26 +487,25 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) {
var (
cmsgs socket.ControlMessages
- fds linux.ControlMessageRights
+ fds []primitive.Int32
)
- for i := 0; i < len(buf); {
- if i+linux.SizeOfControlMessageHeader > len(buf) {
+ for len(buf) > 0 {
+ if linux.SizeOfControlMessageHeader > len(buf) {
return cmsgs, linuxerr.EINVAL
}
var h linux.ControlMessageHeader
- h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
+ buf = h.UnmarshalUnsafe(buf)
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
return socket.ControlMessages{}, linuxerr.EINVAL
}
- if h.Length > uint64(len(buf)-i) {
- return socket.ControlMessages{}, linuxerr.EINVAL
- }
- i += linux.SizeOfControlMessageHeader
length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length > len(buf) {
+ return socket.ControlMessages{}, linuxerr.EINVAL
+ }
switch h.Level {
case linux.SOL_SOCKET:
@@ -518,11 +518,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, linuxerr.EINVAL
}
- for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
- fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
- }
-
- i += bits.AlignUp(length, width)
+ curFDs := make([]primitive.Int32, numRights)
+ primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize])
+ fds = append(fds, curFDs...)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -530,23 +528,21 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var creds linux.ControlMessageCredentials
- creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
+ creds.UnmarshalUnsafe(buf)
scmCreds, err := NewSCMCredentials(t, creds)
if err != nil {
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += bits.AlignUp(length, width)
case linux.SO_TIMESTAMP:
if length < linux.SizeOfTimeval {
return socket.ControlMessages{}, linuxerr.EINVAL
}
var ts linux.Timeval
- ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
+ ts.UnmarshalUnsafe(buf)
cmsgs.IP.Timestamp = ts.ToTime()
cmsgs.IP.HasTimestamp = true
- i += bits.AlignUp(length, width)
default:
// Unknown message type.
@@ -560,9 +556,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
cmsgs.IP.HasTOS = true
var tos primitive.Uint8
- tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS])
+ tos.UnmarshalUnsafe(buf)
cmsgs.IP.TOS = uint8(tos)
- i += bits.AlignUp(length, width)
case linux.IP_PKTINFO:
if length < linux.SizeOfControlMessageIPPacketInfo {
@@ -571,19 +566,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
cmsgs.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo])
-
+ packetInfo.UnmarshalUnsafe(buf)
cmsgs.IP.PacketInfo = packetInfo
- i += bits.AlignUp(length, width)
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
if length < addr.SizeBytes() {
return socket.ControlMessages{}, linuxerr.EINVAL
}
- addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
+ addr.UnmarshalUnsafe(buf)
cmsgs.IP.OriginalDstAddress = &addr
- i += bits.AlignUp(length, width)
case linux.IP_RECVERR:
var errCmsg linux.SockErrCMsgIPv4
@@ -591,9 +583,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, linuxerr.EINVAL
}
- errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ errCmsg.UnmarshalBytes(buf)
cmsgs.IP.SockErr = &errCmsg
- i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, linuxerr.EINVAL
@@ -606,18 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
cmsgs.IP.HasTClass = true
var tclass primitive.Uint32
- tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass])
+ tclass.UnmarshalUnsafe(buf)
cmsgs.IP.TClass = uint32(tclass)
- i += bits.AlignUp(length, width)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
if length < addr.SizeBytes() {
return socket.ControlMessages{}, linuxerr.EINVAL
}
- addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
+ addr.UnmarshalUnsafe(buf)
cmsgs.IP.OriginalDstAddress = &addr
- i += bits.AlignUp(length, width)
case linux.IPV6_RECVERR:
var errCmsg linux.SockErrCMsgIPv6
@@ -625,9 +614,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, linuxerr.EINVAL
}
- errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ errCmsg.UnmarshalBytes(buf)
cmsgs.IP.SockErr = &errCmsg
- i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, linuxerr.EINVAL
@@ -635,6 +623,11 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
default:
return socket.ControlMessages{}, linuxerr.EINVAL
}
+ if shift := bits.AlignUp(length, width); shift > len(buf) {
+ buf = buf[:0]
+ } else {
+ buf = buf[shift:]
+ }
}
if cmsgs.Unix.Credentials == nil {
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
index 0a989cbeb..a638cb955 100644
--- a/pkg/sentry/socket/control/control_vfs2.go
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -45,10 +46,10 @@ type RightsFilesVFS2 []*vfs.FileDescription
// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
// representation using local sentry FDs.
-func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) {
+func NewSCMRightsVFS2(t *kernel.Task, fds []primitive.Int32) (SCMRightsVFS2, error) {
files := make(RightsFilesVFS2, 0, len(fds))
for _, fd := range fds {
- file := t.GetFileVFS2(fd)
+ file := t.GetFileVFS2(int32(fd))
if file == nil {
files.Release(t)
return nil, linuxerr.EBADF
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 6e2318f75..a31f3ebec 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -578,7 +578,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.SO_TIMESTAMP:
controlMessages.IP.HasTimestamp = true
ts := linux.Timeval{}
- ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval])
+ ts.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.Timestamp = ts.ToTime()
}
@@ -587,18 +587,18 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.IP_TOS:
controlMessages.IP.HasTOS = true
var tos primitive.Uint8
- tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()])
+ tos.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.TOS = uint8(tos)
case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()])
+ packetInfo.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.PacketInfo = packetInfo
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
- addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.OriginalDstAddress = &addr
case unix.IP_RECVERR:
@@ -612,12 +612,12 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
var tclass primitive.Uint32
- tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()])
+ tclass.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.TClass = uint32(tclass)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
- addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.OriginalDstAddress = &addr
case unix.IPV6_RECVERR:
@@ -631,7 +631,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.TCP_INQ:
controlMessages.IP.HasInq = true
var inq primitive.Int32
- inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq])
+ inq.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.Inq = int32(inq)
}
}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 61111ac6c..c84ab3fb7 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -138,7 +138,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg)
}
var ifinfo linux.InterfaceInfoMessage
- ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()])
+ ifinfo.UnmarshalUnsafe(link.Data)
inetIF := inet.Interface{
DeviceType: ifinfo.Type,
Flags: ifinfo.Flags,
@@ -169,7 +169,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg)
}
var ifaddr linux.InterfaceAddrMessage
- ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()])
+ ifaddr.UnmarshalUnsafe(addr.Data)
inetAddr := inet.InterfaceAddr{
Family: ifaddr.Family,
PrefixLen: ifaddr.PrefixLen,
@@ -201,7 +201,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
}
var ifRoute linux.RouteMessage
- ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()])
+ ifRoute.UnmarshalUnsafe(routeMsg.Data)
inetRoute := inet.Route{
Family: ifRoute.Family,
DstLen: ifRoute.DstLen,
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 3f1b4a17b..7606d2bbb 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -80,9 +80,8 @@ func marshalEntryMatch(name string, data []byte) []byte {
copy(matcher.Name[:], name)
buf := make([]byte, size)
- entryLen := matcher.XTEntryMatch.SizeBytes()
- matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen])
- copy(buf[entryLen:], matcher.Data)
+ bufRemain := matcher.XTEntryMatch.MarshalUnsafe(buf)
+ copy(bufRemain, matcher.Data)
return buf
}
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index af31cbc5b..6cbfee8b6 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -141,10 +141,9 @@ func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace
nflog("optVal has insufficient size for entry %d", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- var entry linux.IPTEntry
- entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[entry.SizeBytes():]
+ var entry linux.IPTEntry
+ optVal = entry.UnmarshalUnsafe(optVal)
if entry.TargetOffset < linux.SizeOfIPTEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 6cefe0b9c..902707abf 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -144,10 +144,9 @@ func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace
nflog("optVal has insufficient size for entry %d", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- var entry linux.IP6TEntry
- entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[entry.SizeBytes():]
+ var entry linux.IP6TEntry
+ optVal = entry.UnmarshalUnsafe(optVal)
if entry.TargetOffset < linux.SizeOfIP6TEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 01f2f8c77..f2cdf091d 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -176,9 +176,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
- replaceBuf := optVal[:linux.SizeOfIPTReplace]
- optVal = optVal[linux.SizeOfIPTReplace:]
- replace.UnmarshalBytes(replaceBuf)
+ optVal = replace.UnmarshalBytes(optVal)
var table stack.Table
switch replace.Name.String() {
@@ -306,8 +304,7 @@ func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte
return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
}
var match linux.XTEntryMatch
- buf := optVal[:match.SizeBytes()]
- match.UnmarshalUnsafe(buf)
+ match.UnmarshalUnsafe(optVal)
nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
// Check some invariants.
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 6eff2ae65..d83d7e535 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -73,7 +73,7 @@ func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHe
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.IPTOwnerInfo
- matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo])
+ matchData.UnmarshalUnsafe(buf)
nflog("parsed IPTOwnerInfo: %+v", matchData)
var owner OwnerMatcher
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index b9c15daab..eaf601543 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -199,7 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
var standardTarget linux.XTStandardTarget
- standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()])
+ standardTarget.UnmarshalUnsafe(buf)
if standardTarget.Verdict < 0 {
// A Verdict < 0 indicates a non-jump verdict.
@@ -253,7 +253,6 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
return nil, syserr.ErrInvalidArgument
}
var errTgt linux.XTErrorTarget
- buf = buf[:linux.SizeOfXTErrorTarget]
errTgt.UnmarshalUnsafe(buf)
// Error targets are used in 2 cases:
@@ -316,7 +315,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
}
var rt linux.XTRedirectTarget
- buf = buf[:linux.SizeOfXTRedirectTarget]
rt.UnmarshalUnsafe(buf)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
@@ -405,8 +403,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var natRange linux.NFNATRange
- buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
- natRange.UnmarshalUnsafe(buf)
+ natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:])
// We don't support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -477,7 +474,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
}
var st linux.XTSNATTarget
- buf = buf[:linux.SizeOfXTSNATTarget]
st.UnmarshalUnsafe(buf)
// Copy linux.XTSNATTarget to stack.SNATTarget.
@@ -557,8 +553,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
}
var natRange linux.NFNATRange
- buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
- natRange.UnmarshalUnsafe(buf)
+ natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:])
// TODO(gvisor.dev/issue/5697): Support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -621,7 +616,9 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T
return nil, syserr.ErrInvalidArgument
}
var target linux.XTEntryTarget
- target.UnmarshalUnsafe(optVal[:target.SizeBytes()])
+ // Do not advance optVal as targetMake.unmarshal() may unmarshal
+ // XTEntryTarget again but with some added fields.
+ target.UnmarshalUnsafe(optVal)
return unmarshalTarget(target, filter, optVal)
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index e5b73a976..a621a6a16 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -59,7 +59,7 @@ func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.XTTCP
- matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
+ matchData.UnmarshalUnsafe(buf)
nflog("parseMatchers: parsed XTTCP: %+v", matchData)
if matchData.Option != 0 ||
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index aa72ee70c..2ca854764 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -59,7 +59,7 @@ func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF
// For alignment reasons, the match's total size may exceed what's
// strictly necessary to hold matchData.
var matchData linux.XTUDP
- matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
+ matchData.UnmarshalUnsafe(buf)
nflog("parseMatchers: parsed XTUDP: %+v", matchData)
if matchData.InverseFlags != 0 {
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
index 968968469..1604b2792 100644
--- a/pkg/sentry/socket/netlink/message_test.go
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -25,33 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)
-type dummyNetlinkMsg struct {
- marshal.StubMarshallable
- Foo uint16
-}
-
-func (*dummyNetlinkMsg) SizeBytes() int {
- return 2
-}
-
-func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) {
- p := primitive.Uint16(m.Foo)
- p.MarshalUnsafe(dst)
-}
-
-func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) {
- var p primitive.Uint16
- p.UnmarshalUnsafe(src)
- m.Foo = uint16(p)
-}
-
func TestParseMessage(t *testing.T) {
+ dummyNetlinkMsg := primitive.Uint16(0x3130)
tests := []struct {
desc string
input []byte
header linux.NetlinkMessageHeader
- dataMsg *dummyNetlinkMsg
+ dataMsg marshal.Marshallable
restLen int
ok bool
}{
@@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 1,
ok: true,
},
@@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) {
t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
}
- dataMsg := &dummyNetlinkMsg{}
- _, dataOk := msg.GetData(dataMsg)
+ var dataMsg primitive.Uint16
+ _, dataOk := msg.GetData(&dataMsg)
if !dataOk {
t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
- } else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
+ } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) {
t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 267155807..19c8f340d 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
}
var sa linux.SockAddrNetlink
- sa.UnmarshalUnsafe(b[:sa.SizeBytes()])
+ sa.UnmarshalUnsafe(b)
if sa.Family != linux.AF_NETLINK {
return nil, syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index c35cf06f6..e38c4e5da 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -660,7 +660,7 @@ func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < sockAddrLinkSize {
return syserr.ErrInvalidArgument
}
- a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
+ a.UnmarshalBytes(sockaddr)
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
@@ -1839,7 +1839,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
+ v.UnmarshalBytes(optVal)
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1852,7 +1852,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
+ v.UnmarshalBytes(optVal)
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1883,7 +1883,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Linger
- v.UnmarshalBytes(optVal[:linux.SizeOfLinger])
+ v.UnmarshalBytes(optVal)
if v != (linux.Linger{}) {
socket.SetSockOptEmitUnimplementedEvent(t, name)
@@ -2222,12 +2222,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
if len(optVal) >= inetMulticastRequestWithNICSize {
var req linux.InetMulticastRequestWithNIC
- req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize])
+ req.UnmarshalUnsafe(optVal)
return req, nil
}
var req linux.InetMulticastRequestWithNIC
- req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize])
+ req.InetMulticastRequest.UnmarshalUnsafe(optVal)
return req, nil
}
@@ -2237,7 +2237,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse
}
var req linux.Inet6MulticastRequest
- req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize])
+ req.UnmarshalUnsafe(optVal)
return req, nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fc5431eb1..01073df72 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -595,19 +595,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
switch family {
case unix.AF_INET:
var addr linux.SockAddrInet
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
case unix.AF_INET6:
var addr linux.SockAddrInet6
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
case unix.AF_UNIX:
var addr linux.SockAddrUnix
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
case unix.AF_NETLINK:
var addr linux.SockAddrNetlink
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
default:
panic(fmt.Sprintf("Unsupported socket family %v", family))
@@ -738,7 +738,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInetSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- a.UnmarshalUnsafe(addr[:sockAddrInetSize])
+ a.UnmarshalUnsafe(addr)
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -751,7 +751,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInet6Size {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- a.UnmarshalUnsafe(addr[:sockAddrInet6Size])
+ a.UnmarshalUnsafe(addr)
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -767,7 +767,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrLinkSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- a.UnmarshalUnsafe(addr[:sockAddrLinkSize])
+ a.UnmarshalUnsafe(addr)
// TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have
// an ethernet address.
if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go
index f4aab25b0..b164d9107 100644
--- a/pkg/sentry/strace/socket.go
+++ b/pkg/sentry/strace/socket.go
@@ -161,12 +161,10 @@ var controlMessageType = map[int32]string{
linux.SO_TIMESTAMP: "SO_TIMESTAMP",
}
-func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights {
+func unmarshalControlMessageRights(src []byte) []primitive.Int32 {
count := len(src) / linux.SizeOfControlMessageRight
- cmr := make(linux.ControlMessageRights, count)
- for i, _ := range cmr {
- cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:]))
- }
+ cmr := make([]primitive.Int32, count)
+ primitive.UnmarshalUnsafeInt32Slice(cmr, src)
return cmr
}
@@ -182,14 +180,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
var strs []string
- for i := 0; i < len(buf); {
- if i+linux.SizeOfControlMessageHeader > len(buf) {
+ for len(buf) > 0 {
+ if linux.SizeOfControlMessageHeader > len(buf) {
strs = append(strs, "{invalid control message (too short)}")
break
}
var h linux.ControlMessageHeader
- h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
+ buf = h.UnmarshalUnsafe(buf)
var skipData bool
level := "SOL_SOCKET"
@@ -204,7 +202,9 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
typ = fmt.Sprint(h.Type)
}
- if h.Length > uint64(len(buf)-i) {
+ width := t.Arch().Width()
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length > len(buf) {
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, content extends beyond buffer}",
level,
@@ -214,9 +214,6 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
break
}
- i += linux.SizeOfControlMessageHeader
- width := t.Arch().Width()
- length := int(h.Length) - linux.SizeOfControlMessageHeader
if length < 0 {
strs = append(strs, fmt.Sprintf(
"{level=%s, type=%s, length=%d, content too short}",
@@ -229,78 +226,80 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64)
if skipData {
strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length))
- i += bits.AlignUp(length, width)
- continue
- }
-
- switch h.Type {
- case linux.SCM_RIGHTS:
- rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
- fds := unmarshalControlMessageRights(buf[i : i+rightsSize])
- rights := make([]string, 0, len(fds))
- for _, fd := range fds {
- rights = append(rights, fmt.Sprint(fd))
- }
+ } else {
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight)
+ fds := unmarshalControlMessageRights(buf[:rightsSize])
+ rights := make([]string, 0, len(fds))
+ for _, fd := range fds {
+ rights = append(rights, fmt.Sprint(fd))
+ }
- strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, content: %s}",
- level,
- typ,
- h.Length,
- strings.Join(rights, ","),
- ))
-
- case linux.SCM_CREDENTIALS:
- if length < linux.SizeOfControlMessageCredentials {
strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, content too short}",
+ "{level=%s, type=%s, length=%d, content: %s}",
level,
typ,
h.Length,
+ strings.Join(rights, ","),
))
- break
- }
- var creds linux.ControlMessageCredentials
- creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
- strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
- level,
- typ,
- h.Length,
- creds.PID,
- creds.UID,
- creds.GID,
- ))
+ var creds linux.ControlMessageCredentials
+ creds.UnmarshalUnsafe(buf)
- case linux.SO_TIMESTAMP:
- if length < linux.SizeOfTimeval {
strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, content too short}",
+ "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}",
level,
typ,
h.Length,
+ creds.PID,
+ creds.UID,
+ creds.GID,
))
- break
- }
- var tv linux.Timeval
- tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, content too short}",
+ level,
+ typ,
+ h.Length,
+ ))
+ break
+ }
- strs = append(strs, fmt.Sprintf(
- "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
- level,
- typ,
- h.Length,
- tv.Sec,
- tv.Usec,
- ))
+ var tv linux.Timeval
+ tv.UnmarshalUnsafe(buf)
- default:
- panic("unreachable")
+ strs = append(strs, fmt.Sprintf(
+ "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}",
+ level,
+ typ,
+ h.Length,
+ tv.Sec,
+ tv.Usec,
+ ))
+
+ default:
+ panic("unreachable")
+ }
+ }
+ if shift := bits.AlignUp(length, width); shift > len(buf) {
+ buf = buf[:0]
+ } else {
+ buf = buf[shift:]
}
- i += bits.AlignUp(length, width)
}
return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", "))