From ce4f4283badb6b07baf9f8e6d99e7a5fd15c92db Mon Sep 17 00:00:00 2001 From: Ayush Ranjan Date: Fri, 5 Nov 2021 10:40:53 -0700 Subject: Make {Un}Marshal{Bytes/Unsafe} return remaining buffer. Change marshal.Marshallable method signatures to return the remaining buffer. This makes it easier to implement these method manually. Without this, we would have to manually do buffer shifting which is error prone. tools/go_marshal/test:benchmark test does not show change in performance. Additionally fixed some marshalling bugs in fsimpl/fuse. Updated multiple callpoints to get rid of redundant slice indexing work and simplified code using this new signature. Updates #6450 PiperOrigin-RevId: 407857019 --- pkg/sentry/fsimpl/fuse/BUILD | 3 +- pkg/sentry/fsimpl/fuse/connection_test.go | 11 +-- pkg/sentry/fsimpl/fuse/dev_test.go | 19 ++-- pkg/sentry/fsimpl/fuse/fusefs.go | 16 ++-- pkg/sentry/fsimpl/fuse/request_response.go | 4 +- pkg/sentry/fsimpl/fuse/utils_test.go | 59 ------------ pkg/sentry/loader/elf.go | 3 +- pkg/sentry/socket/control/control.go | 63 ++++++------- pkg/sentry/socket/control/control_vfs2.go | 5 +- pkg/sentry/socket/hostinet/socket.go | 14 +-- pkg/sentry/socket/hostinet/stack.go | 6 +- pkg/sentry/socket/netfilter/extensions.go | 5 +- pkg/sentry/socket/netfilter/ipv4.go | 5 +- pkg/sentry/socket/netfilter/ipv6.go | 5 +- pkg/sentry/socket/netfilter/netfilter.go | 7 +- pkg/sentry/socket/netfilter/owner_matcher.go | 2 +- pkg/sentry/socket/netfilter/targets.go | 15 ++- pkg/sentry/socket/netfilter/tcp_matcher.go | 2 +- pkg/sentry/socket/netfilter/udp_matcher.go | 2 +- pkg/sentry/socket/netlink/message_test.go | 45 ++------- pkg/sentry/socket/netlink/socket.go | 2 +- pkg/sentry/socket/netstack/netstack.go | 14 +-- pkg/sentry/socket/socket.go | 14 +-- pkg/sentry/strace/socket.go | 131 +++++++++++++-------------- 24 files changed, 172 insertions(+), 280 deletions(-) (limited to 'pkg/sentry') diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 05c4fbeb2..18497a880 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -77,8 +77,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 1fddd858e..d98d2832b 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) @@ -69,14 +70,10 @@ func TestConnectionAbort(t *testing.T) { t.Fatalf("newTestConnection: %v", err) } - testObj := &testPayload{ - data: rand.Uint32(), - } - var futNormal []*futureResponse - + testObj := primitive.Uint32(rand.Uint32()) for i := 0; i < int(numRequests); i++ { - req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) + req := conn.NewRequest(creds, uint32(i), uint64(i), 0, &testObj) fut, err := conn.callFutureLocked(task, req) if err != nil { t.Fatalf("callFutureLocked failed: %v", err) @@ -102,7 +99,7 @@ func TestConnectionAbort(t *testing.T) { } // After abort, Call() should return directly with ENOTCONN. - req := conn.NewRequest(creds, 0, 0, 0, testObj) + req := conn.NewRequest(creds, 0, 0, 0, &testObj) _, err = conn.Call(task, req) if !linuxerr.Equals(linuxerr.ENOTCONN, err) { t.Fatalf("Incorrect error code received for Call() after connection aborted") diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 8951b5ba8..13b32fc7c 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -215,11 +216,9 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con if err != nil { t.Fatal(err) } - testObj := &testPayload{ - data: rand.Uint32(), - } - req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) + testObj := primitive.Uint32(rand.Uint32()) + req := conn.NewRequest(creds, pid, inode, echoTestOpcode, &testObj) // Queue up a request. // Analogous to Call except it doesn't block on the task. @@ -232,7 +231,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con t.Fatalf("Server responded with an error: %v", err) } - var respTestPayload testPayload + var respTestPayload primitive.Uint32 if err := resp.UnmarshalPayload(&respTestPayload); err != nil { t.Fatalf("Unmarshalling payload error: %v", err) } @@ -242,8 +241,8 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con req.hdr.Unique, resp.hdr.Unique) } - if respTestPayload.data != testObj.data { - t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data) + if respTestPayload != testObj { + t.Fatalf("read incorrect data. Data expected: %d, but got %d", testObj, respTestPayload) } } @@ -256,8 +255,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F // Create the tasks that the server will be using. tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - var readPayload testPayload + var readPayload primitive.Uint32 serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root) if err != nil { t.Fatal(err) @@ -291,8 +290,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F } var readFUSEHeaderIn linux.FUSEHeaderIn - readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen]) - readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen]) + inBuf = readFUSEHeaderIn.UnmarshalUnsafe(inBuf) + readPayload.UnmarshalUnsafe(inBuf) if readFUSEHeaderIn.Opcode != echoTestOpcode { t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload) diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index af16098d2..00b520c31 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -489,7 +489,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr // Lookup implements kernfs.Inode.Lookup. func (i *inode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) { - in := linux.FUSELookupIn{Name: name} + in := linux.FUSELookupIn{Name: linux.CString(name)} return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in) } @@ -520,7 +520,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) Mode: uint32(opts.Mode) | linux.S_IFREG, Umask: uint32(kernelTask.FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in) } @@ -533,7 +533,7 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor), Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in) } @@ -541,8 +541,8 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) // NewSymlink implements kernfs.Inode.NewSymlink. func (i *inode) NewSymlink(ctx context.Context, name, target string) (kernfs.Inode, error) { in := linux.FUSESymLinkIn{ - Name: name, - Target: target, + Name: linux.CString(name), + Target: linux.CString(target), } return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in) } @@ -554,7 +554,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) return linuxerr.EINVAL } - in := linux.FUSEUnlinkIn{Name: name} + in := linux.FUSEUnlinkIn{Name: linux.CString(name)} req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { @@ -571,7 +571,7 @@ func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) Mode: uint32(opts.Mode), Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in) } @@ -581,7 +581,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro fusefs := i.fs task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) - in := linux.FUSERmDirIn{Name: name} + in := linux.FUSERmDirIn{Name: linux.CString(name)} req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index 8a72489fa..ec76ec2a4 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -41,7 +41,7 @@ type fuseInitRes struct { } // UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes. -func (r *fuseInitRes) UnmarshalBytes(src []byte) { +func (r *fuseInitRes) UnmarshalBytes(src []byte) []byte { out := &r.initOut // Introduced before FUSE kernel version 7.13. @@ -70,7 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(hostarch.ByteOrder.Uint16(src[:2])) src = src[2:] } - _ = src // Remove unused warning. + return src } // SizeBytes is the size of the payload of the FUSE_INIT response. diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go index b0bab0066..8d4a2fad3 100644 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -15,17 +15,13 @@ package fuse import ( - "io" "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - - "gvisor.dev/gvisor/pkg/hostarch" ) func setup(t *testing.T) *testutil.System { @@ -70,58 +66,3 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque } return fs.conn, &fuseDev.vfsfd, nil } - -type testPayload struct { - marshal.StubMarshallable - data uint32 -} - -// SizeBytes implements marshal.Marshallable.SizeBytes. -func (t *testPayload) SizeBytes() int { - return 4 -} - -// MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *testPayload) MarshalBytes(dst []byte) { - hostarch.ByteOrder.PutUint32(dst[:4], t.data) -} - -// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *testPayload) UnmarshalBytes(src []byte) { - *t = testPayload{data: hostarch.ByteOrder.Uint32(src[:4])} -} - -// Packed implements marshal.Marshallable.Packed. -func (t *testPayload) Packed() bool { - return true -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (t *testPayload) MarshalUnsafe(dst []byte) { - t.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (t *testPayload) UnmarshalUnsafe(src []byte) { - t.UnmarshalBytes(src) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (t *testPayload) CopyOutN(task marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { - panic("not implemented") -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (t *testPayload) CopyOut(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (t *testPayload) CopyIn(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// WriteTo implements io.WriterTo.WriteTo. -func (t *testPayload) WriteTo(w io.Writer) (int64, error) { - panic("not implemented") -} diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index fb213d109..09b148164 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -212,8 +212,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { phdrs := make([]elf.ProgHeader, hdr.Phnum) for i := range phdrs { var prog64 linux.ElfProg64 - prog64.UnmarshalUnsafe(phdrBuf[:prog64Size]) - phdrBuf = phdrBuf[prog64Size:] + phdrBuf = prog64.UnmarshalUnsafe(phdrBuf) phdrs[i] = elf.ProgHeader{ Type: elf.ProgType(prog64.Type), Flags: elf.ProgFlag(prog64.Flags), diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 6077b2150..4b036b323 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -17,6 +17,8 @@ package control import ( + "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" @@ -29,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "time" ) // SCMCredentials represents a SCM_CREDENTIALS socket control message. @@ -63,10 +64,10 @@ type RightsFiles []*fs.File // NewSCMRights creates a new SCM_RIGHTS socket control message representation // using local sentry FDs. -func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { +func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) { files := make(RightsFiles, 0, len(fds)) for _, fd := range fds { - file := t.GetFile(fd) + file := t.GetFile(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF @@ -486,26 +487,25 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) { var ( cmsgs socket.ControlMessages - fds linux.ControlMessageRights + fds []primitive.Int32 ) - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { return cmsgs, linuxerr.EINVAL } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, linuxerr.EINVAL } - if h.Length > uint64(len(buf)-i) { - return socket.ControlMessages{}, linuxerr.EINVAL - } - i += linux.SizeOfControlMessageHeader length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { + return socket.ControlMessages{}, linuxerr.EINVAL + } switch h.Level { case linux.SOL_SOCKET: @@ -518,11 +518,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) - } - - i += bits.AlignUp(length, width) + curFDs := make([]primitive.Int32, numRights) + primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize]) + fds = append(fds, curFDs...) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -530,23 +528,21 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + creds.UnmarshalUnsafe(buf) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, linuxerr.EINVAL } var ts linux.Timeval - ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(buf) cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true - i += bits.AlignUp(length, width) default: // Unknown message type. @@ -560,9 +556,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + tos.UnmarshalUnsafe(buf) cmsgs.IP.TOS = uint8(tos) - i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -571,19 +566,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) - + packetInfo.UnmarshalUnsafe(buf) cmsgs.IP.PacketInfo = packetInfo - i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -591,9 +583,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -606,18 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + tclass.UnmarshalUnsafe(buf) cmsgs.IP.TClass = uint32(tclass) - i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -625,9 +614,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -635,6 +623,11 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) default: return socket.ControlMessages{}, linuxerr.EINVAL } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] + } } if cmsgs.Unix.Credentials == nil { diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index 0a989cbeb..a638cb955 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -45,10 +46,10 @@ type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message // representation using local sentry FDs. -func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { +func NewSCMRightsVFS2(t *kernel.Task, fds []primitive.Int32) (SCMRightsVFS2, error) { files := make(RightsFilesVFS2, 0, len(fds)) for _, fd := range fds { - file := t.GetFileVFS2(fd) + file := t.GetFileVFS2(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 6e2318f75..a31f3ebec 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -578,7 +578,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} - ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Timestamp = ts.ToTime() } @@ -587,18 +587,18 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IP_TOS: controlMessages.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + tos.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) + packetInfo.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -612,12 +612,12 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + tclass.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -631,7 +631,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.TCP_INQ: controlMessages.IP.HasInq = true var inq primitive.Int32 - inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + inq.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Inq = int32(inq) } } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 61111ac6c..c84ab3fb7 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -138,7 +138,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } var ifinfo linux.InterfaceInfoMessage - ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) + ifinfo.UnmarshalUnsafe(link.Data) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -169,7 +169,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } var ifaddr linux.InterfaceAddrMessage - ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) + ifaddr.UnmarshalUnsafe(addr.Data) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, PrefixLen: ifaddr.PrefixLen, @@ -201,7 +201,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) } var ifRoute linux.RouteMessage - ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) + ifRoute.UnmarshalUnsafe(routeMsg.Data) inetRoute := inet.Route{ Family: ifRoute.Family, DstLen: ifRoute.DstLen, diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 3f1b4a17b..7606d2bbb 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -80,9 +80,8 @@ func marshalEntryMatch(name string, data []byte) []byte { copy(matcher.Name[:], name) buf := make([]byte, size) - entryLen := matcher.XTEntryMatch.SizeBytes() - matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) - copy(buf[entryLen:], matcher.Data) + bufRemain := matcher.XTEntryMatch.MarshalUnsafe(buf) + copy(bufRemain, matcher.Data) return buf } diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index af31cbc5b..6cbfee8b6 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -141,10 +141,9 @@ func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IPTEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IPTEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 6cefe0b9c..902707abf 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -144,10 +144,9 @@ func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IP6TEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IP6TEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 01f2f8c77..f2cdf091d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -176,9 +176,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint // net/ipv4/netfilter/ip_tables.c:translate_table for reference. func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace - replaceBuf := optVal[:linux.SizeOfIPTReplace] - optVal = optVal[linux.SizeOfIPTReplace:] - replace.UnmarshalBytes(replaceBuf) + optVal = replace.UnmarshalBytes(optVal) var table stack.Table switch replace.Name.String() { @@ -306,8 +304,7 @@ func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:match.SizeBytes()] - match.UnmarshalUnsafe(buf) + match.UnmarshalUnsafe(optVal) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 6eff2ae65..d83d7e535 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -73,7 +73,7 @@ func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHe // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) + matchData.UnmarshalUnsafe(buf) nflog("parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index b9c15daab..eaf601543 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -199,7 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) + standardTarget.UnmarshalUnsafe(buf) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -253,7 +253,6 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return nil, syserr.ErrInvalidArgument } var errTgt linux.XTErrorTarget - buf = buf[:linux.SizeOfXTErrorTarget] errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: @@ -316,7 +315,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( } var rt linux.XTRedirectTarget - buf = buf[:linux.SizeOfXTRedirectTarget] rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. @@ -405,8 +403,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -477,7 +474,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var st linux.XTSNATTarget - buf = buf[:linux.SizeOfXTSNATTarget] st.UnmarshalUnsafe(buf) // Copy linux.XTSNATTarget to stack.SNATTarget. @@ -557,8 +553,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // TODO(gvisor.dev/issue/5697): Support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -621,7 +616,9 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) + // Do not advance optVal as targetMake.unmarshal() may unmarshal + // XTEntryTarget again but with some added fields. + target.UnmarshalUnsafe(optVal) return unmarshalTarget(target, filter, optVal) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index e5b73a976..a621a6a16 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -59,7 +59,7 @@ func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.XTTCP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index aa72ee70c..2ca854764 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -59,7 +59,7 @@ func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. var matchData linux.XTUDP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index 968968469..1604b2792 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -25,33 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) -type dummyNetlinkMsg struct { - marshal.StubMarshallable - Foo uint16 -} - -func (*dummyNetlinkMsg) SizeBytes() int { - return 2 -} - -func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { - p := primitive.Uint16(m.Foo) - p.MarshalUnsafe(dst) -} - -func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { - var p primitive.Uint16 - p.UnmarshalUnsafe(src) - m.Foo = uint16(p) -} - func TestParseMessage(t *testing.T) { + dummyNetlinkMsg := primitive.Uint16(0x3130) tests := []struct { desc string input []byte header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg + dataMsg marshal.Marshallable restLen int ok bool }{ @@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 1, ok: true, }, @@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) { t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) } - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) + var dataMsg primitive.Uint16 + _, dataOk := msg.GetData(&dataMsg) if !dataOk { t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) { t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 267155807..19c8f340d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) + sa.UnmarshalUnsafe(b) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c35cf06f6..e38c4e5da 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -660,7 +660,7 @@ func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) + a.UnmarshalBytes(sockaddr) addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), @@ -1839,7 +1839,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1852,7 +1852,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1883,7 +1883,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + v.UnmarshalBytes(optVal) if v != (linux.Linger{}) { socket.SetSockOptEmitUnimplementedEvent(t, name) @@ -2222,12 +2222,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) + req.UnmarshalUnsafe(optVal) return req, nil } var req linux.InetMulticastRequestWithNIC - req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) + req.InetMulticastRequest.UnmarshalUnsafe(optVal) return req, nil } @@ -2237,7 +2237,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) + req.UnmarshalUnsafe(optVal) return req, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fc5431eb1..01073df72 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -595,19 +595,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -738,7 +738,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInetSize]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -751,7 +751,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -767,7 +767,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) + a.UnmarshalUnsafe(addr) // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have // an ethernet address. if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f4aab25b0..b164d9107 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -161,12 +161,10 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } -func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { +func unmarshalControlMessageRights(src []byte) []primitive.Int32 { count := len(src) / linux.SizeOfControlMessageRight - cmr := make(linux.ControlMessageRights, count) - for i, _ := range cmr { - cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) - } + cmr := make([]primitive.Int32, count) + primitive.UnmarshalUnsafeInt32Slice(cmr, src) return cmr } @@ -182,14 +180,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) var strs []string - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { strs = append(strs, "{invalid control message (too short)}") break } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) var skipData bool level := "SOL_SOCKET" @@ -204,7 +202,9 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) typ = fmt.Sprint(h.Type) } - if h.Length > uint64(len(buf)-i) { + width := t.Arch().Width() + length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content extends beyond buffer}", level, @@ -214,9 +214,6 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) break } - i += linux.SizeOfControlMessageHeader - width := t.Arch().Width() - length := int(h.Length) - linux.SizeOfControlMessageHeader if length < 0 { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content too short}", @@ -229,78 +226,80 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += bits.AlignUp(length, width) - continue - } - - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) - fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) - rights := make([]string, 0, len(fds)) - for _, fd := range fds { - rights = append(rights, fmt.Sprint(fd)) - } + } else { + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[:rightsSize]) + rights := make([]string, 0, len(fds)) + for _, fd := range fds { + rights = append(rights, fmt.Sprint(fd)) + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content: %s}", - level, - typ, - h.Length, - strings.Join(rights, ","), - )) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, content: %s}", level, typ, h.Length, + strings.Join(rights, ","), )) - break - } - var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", - level, - typ, - h.Length, - creds.PID, - creds.UID, - creds.GID, - )) + var creds linux.ControlMessageCredentials + creds.UnmarshalUnsafe(buf) - case linux.SO_TIMESTAMP: - if length < linux.SizeOfTimeval { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", level, typ, h.Length, + creds.PID, + creds.UID, + creds.GID, )) - break - } - var tv linux.Timeval - tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", - level, - typ, - h.Length, - tv.Sec, - tv.Usec, - )) + var tv linux.Timeval + tv.UnmarshalUnsafe(buf) - default: - panic("unreachable") + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", + level, + typ, + h.Length, + tv.Sec, + tv.Usec, + )) + + default: + panic("unreachable") + } + } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] } - i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) -- cgit v1.2.3