diff options
author | Ayush Ranjan <ayushranjan@google.com> | 2021-11-05 10:40:53 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-11-05 10:43:49 -0700 |
commit | ce4f4283badb6b07baf9f8e6d99e7a5fd15c92db (patch) | |
tree | 848dc50da62da59dc4a5781f9eb7461c58b71512 | |
parent | 822a647018adbd994114cb0dc8932f2853b805aa (diff) |
Make {Un}Marshal{Bytes/Unsafe} return remaining buffer.
Change marshal.Marshallable method signatures to return the remaining buffer.
This makes it easier to implement these method manually. Without this, we would
have to manually do buffer shifting which is error prone.
tools/go_marshal/test:benchmark test does not show change in performance.
Additionally fixed some marshalling bugs in fsimpl/fuse.
Updated multiple callpoints to get rid of redundant slice indexing work and
simplified code using this new signature.
Updates #6450
PiperOrigin-RevId: 407857019
41 files changed, 629 insertions, 752 deletions
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index 1070b457c..1112dadd6 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -352,6 +352,22 @@ type FUSEEntryOut struct { Attr FUSEAttr } +// CString represents a null terminated string which can be marshalled. +type CString string + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *CString) MarshalBytes(buf []byte) []byte { + copy(buf, *s) + buf[len(*s)] = 0 // null char + return buf[s.SizeBytes():] +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *CString) SizeBytes() int { + // 1 extra byte for null-terminated string. + return len(*s) + 1 +} + // FUSELookupIn is the request sent by the kernel to the daemon // to look up a file name. // @@ -360,18 +376,17 @@ type FUSELookupIn struct { marshal.StubMarshallable // Name is a file name to be looked up. - Name string + Name CString } // MarshalBytes serializes r.name to the dst buffer. -func (r *FUSELookupIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) +func (r *FUSELookupIn) MarshalBytes(buf []byte) []byte { + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSELookupIn. -// 1 extra byte for null-terminated string. func (r *FUSELookupIn) SizeBytes() int { - return len(r.Name) + 1 + return r.Name.SizeBytes() } // MAX_NON_LFS indicates the maximum offset without large file support. @@ -530,19 +545,18 @@ type FUSECreateIn struct { CreateMeta FUSECreateMeta // Name is the name of the node to create. - Name string + Name CString } // MarshalBytes serializes r.CreateMeta and r.Name to the dst buffer. -func (r *FUSECreateIn) MarshalBytes(buf []byte) { - r.CreateMeta.MarshalBytes(buf[:r.CreateMeta.SizeBytes()]) - copy(buf[r.CreateMeta.SizeBytes():], r.Name) +func (r *FUSECreateIn) MarshalBytes(buf []byte) []byte { + buf = r.CreateMeta.MarshalBytes(buf) + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSECreateIn. -// 1 extra byte for null-terminated string. func (r *FUSECreateIn) SizeBytes() int { - return r.CreateMeta.SizeBytes() + len(r.Name) + 1 + return r.CreateMeta.SizeBytes() + r.Name.SizeBytes() } // FUSEMknodMeta contains all the static fields of FUSEMknodIn, @@ -573,19 +587,18 @@ type FUSEMknodIn struct { MknodMeta FUSEMknodMeta // Name is the name of the node to create. - Name string + Name CString } // MarshalBytes serializes r.MknodMeta and r.Name to the dst buffer. -func (r *FUSEMknodIn) MarshalBytes(buf []byte) { - r.MknodMeta.MarshalBytes(buf[:r.MknodMeta.SizeBytes()]) - copy(buf[r.MknodMeta.SizeBytes():], r.Name) +func (r *FUSEMknodIn) MarshalBytes(buf []byte) []byte { + buf = r.MknodMeta.MarshalBytes(buf) + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSEMknodIn. -// 1 extra byte for null-terminated string. func (r *FUSEMknodIn) SizeBytes() int { - return r.MknodMeta.SizeBytes() + len(r.Name) + 1 + return r.MknodMeta.SizeBytes() + r.Name.SizeBytes() } // FUSESymLinkIn is the request sent by the kernel to the daemon, @@ -596,30 +609,30 @@ type FUSESymLinkIn struct { marshal.StubMarshallable // Name of symlink to create. - Name string + Name CString // Target of the symlink. - Target string + Target CString } // MarshalBytes serializes r.Name and r.Target to the dst buffer. -// Left null-termination at end of r.Name and r.Target. -func (r *FUSESymLinkIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) - copy(buf[len(r.Name)+1:], r.Target) +func (r *FUSESymLinkIn) MarshalBytes(buf []byte) []byte { + buf = r.Name.MarshalBytes(buf) + return r.Target.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSESymLinkIn. -// 2 extra bytes for null-terminated string. func (r *FUSESymLinkIn) SizeBytes() int { - return len(r.Name) + len(r.Target) + 2 + return r.Name.SizeBytes() + r.Target.SizeBytes() } // FUSEEmptyIn is used by operations without request body. type FUSEEmptyIn struct{ marshal.StubMarshallable } // MarshalBytes do nothing for marshal. -func (r *FUSEEmptyIn) MarshalBytes(buf []byte) {} +func (r *FUSEEmptyIn) MarshalBytes(buf []byte) []byte { + return buf +} // SizeBytes is 0 for empty request. func (r *FUSEEmptyIn) SizeBytes() int { @@ -649,19 +662,18 @@ type FUSEMkdirIn struct { MkdirMeta FUSEMkdirMeta // Name of the directory to create. - Name string + Name CString } // MarshalBytes serializes r.MkdirMeta and r.Name to the dst buffer. -func (r *FUSEMkdirIn) MarshalBytes(buf []byte) { - r.MkdirMeta.MarshalBytes(buf[:r.MkdirMeta.SizeBytes()]) - copy(buf[r.MkdirMeta.SizeBytes():], r.Name) +func (r *FUSEMkdirIn) MarshalBytes(buf []byte) []byte { + buf = r.MkdirMeta.MarshalBytes(buf) + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSEMkdirIn. -// 1 extra byte for null-terminated Name string. func (r *FUSEMkdirIn) SizeBytes() int { - return r.MkdirMeta.SizeBytes() + len(r.Name) + 1 + return r.MkdirMeta.SizeBytes() + r.Name.SizeBytes() } // FUSERmDirIn is the request sent by the kernel to the daemon @@ -672,17 +684,17 @@ type FUSERmDirIn struct { marshal.StubMarshallable // Name is a directory name to be removed. - Name string + Name CString } // MarshalBytes serializes r.name to the dst buffer. -func (r *FUSERmDirIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) +func (r *FUSERmDirIn) MarshalBytes(buf []byte) []byte { + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSERmDirIn. func (r *FUSERmDirIn) SizeBytes() int { - return len(r.Name) + 1 + return r.Name.SizeBytes() } // FUSEDirents is a list of Dirents received from the FUSE daemon server. @@ -738,7 +750,7 @@ func (r *FUSEDirents) SizeBytes() int { } // UnmarshalBytes deserializes FUSEDirents from the src buffer. -func (r *FUSEDirents) UnmarshalBytes(src []byte) { +func (r *FUSEDirents) UnmarshalBytes(src []byte) []byte { for { if len(src) <= (*FUSEDirentMeta)(nil).SizeBytes() { break @@ -754,11 +766,10 @@ func (r *FUSEDirents) UnmarshalBytes(src []byte) { // to do this. Linux allocates 1 page to store all the dirents and then // simply reads them from the page. var dirent FUSEDirent - dirent.UnmarshalBytes(src) + src = dirent.UnmarshalBytes(src) r.Dirents = append(r.Dirents, &dirent) - - src = src[dirent.SizeBytes():] } + return src } // SizeBytes is the size of the memory representation of FUSEDirent. @@ -772,20 +783,20 @@ func (r *FUSEDirent) SizeBytes() int { } // UnmarshalBytes deserializes FUSEDirent from the src buffer. -func (r *FUSEDirent) UnmarshalBytes(src []byte) { - r.Meta.UnmarshalBytes(src) - src = src[r.Meta.SizeBytes():] +func (r *FUSEDirent) UnmarshalBytes(src []byte) []byte { + src = r.Meta.UnmarshalBytes(src) if r.Meta.NameLen > FUSE_NAME_MAX { // The name is too long and therefore invalid. We don't // need to unmarshal the name since it'll be thrown away. - return + return src } buf := make([]byte, r.Meta.NameLen) name := primitive.ByteSlice(buf) name.UnmarshalBytes(src[:r.Meta.NameLen]) r.Name = string(name) + return src[r.Meta.NameLen:] } // FATTR_* consts are the attribute flags defined in include/uapi/linux/fuse.h. @@ -863,17 +874,15 @@ type FUSEUnlinkIn struct { marshal.StubMarshallable // Name of the node to unlink. - Name string + Name CString } -// MarshalBytes serializes r.name to the dst buffer, which should -// have size len(r.Name) + 1 and last byte set to 0. -func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) []byte { + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSEUnlinkIn. -// 1 extra byte for null-terminated Name string. func (r *FUSEUnlinkIn) SizeBytes() int { - return len(r.Name) + 1 + return r.Name.SizeBytes() } diff --git a/pkg/abi/linux/msgqueue.go b/pkg/abi/linux/msgqueue.go index 0612a8214..6f8eb4dd9 100644 --- a/pkg/abi/linux/msgqueue.go +++ b/pkg/abi/linux/msgqueue.go @@ -82,15 +82,15 @@ func (b *MsgBuf) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (b *MsgBuf) MarshalBytes(dst []byte) { - b.Type.MarshalUnsafe(dst) - b.Text.MarshalBytes(dst[b.Type.SizeBytes():]) +func (b *MsgBuf) MarshalBytes(dst []byte) []byte { + dst = b.Type.MarshalUnsafe(dst) + return b.Text.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (b *MsgBuf) UnmarshalBytes(src []byte) { - b.Type.UnmarshalUnsafe(src) - b.Text.UnmarshalBytes(src[b.Type.SizeBytes():]) +func (b *MsgBuf) UnmarshalBytes(src []byte) []byte { + src = b.Type.UnmarshalUnsafe(src) + return b.Text.UnmarshalBytes(src) } // MsgInfo is equivelant to struct msginfo. Source: include/uapi/linux/msg.h diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 3fd05483a..1470a5578 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -144,15 +144,15 @@ func (ke *KernelIPTEntry) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalUnsafe(dst) - ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) []byte { + dst = ke.Entry.MarshalUnsafe(dst) + return ke.Elems.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalUnsafe(src) - ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) []byte { + src = ke.Entry.UnmarshalUnsafe(src) + return ke.Elems.UnmarshalBytes(src) } var _ marshal.Marshallable = (*KernelIPTEntry)(nil) @@ -455,23 +455,21 @@ func (ke *KernelIPTGetEntries) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalUnsafe(dst) - marshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) []byte { + dst = ke.IPTGetEntries.MarshalUnsafe(dst) for i := range ke.Entrytable { - ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) - marshalledUntil += ke.Entrytable[i].SizeBytes() + dst = ke.Entrytable[i].MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalUnsafe(src) - unmarshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) []byte { + src = ke.IPTGetEntries.UnmarshalUnsafe(src) for i := range ke.Entrytable { - ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) - unmarshalledUntil += ke.Entrytable[i].SizeBytes() + src = ke.Entrytable[i].UnmarshalBytes(src) } + return src } var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index f8c0e891e..aba0202ef 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -84,23 +84,21 @@ func (ke *KernelIP6TGetEntries) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalUnsafe(dst) - marshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) []byte { + dst = ke.IPTGetEntries.MarshalUnsafe(dst) for i := range ke.Entrytable { - ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) - marshalledUntil += ke.Entrytable[i].SizeBytes() + dst = ke.Entrytable[i].MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalUnsafe(src) - unmarshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) []byte { + src = ke.IPTGetEntries.UnmarshalUnsafe(src) for i := range ke.Entrytable { - ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) - unmarshalledUntil += ke.Entrytable[i].SizeBytes() + src = ke.Entrytable[i].UnmarshalBytes(src) } + return src } var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil) @@ -166,17 +164,19 @@ func (ke *KernelIP6TEntry) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalUnsafe(dst) - ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) []byte { + dst = ke.Entry.MarshalUnsafe(dst) + return ke.Elems.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalUnsafe(src) - ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) +func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) []byte { + src = ke.Entry.UnmarshalUnsafe(src) + return ke.Elems.UnmarshalBytes(src) } +var _ marshal.Marshallable = (*KernelIP6TEntry)(nil) + // IP6TIP contains information for matching a packet's IP header. // It corresponds to struct ip6t_ip6 in // include/uapi/linux/netfilter_ipv6/ip6_tables.h. diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index f60e42997..a31690a04 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -555,9 +555,6 @@ type ControlMessageIPv6PacketInfo struct { // ControlMessageCredentials struct. var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes() -// A ControlMessageRights is an SCM_RIGHTS socket control message. -type ControlMessageRights []int32 - // SizeOfControlMessageRight is the size of a single element in // ControlMessageRights. const SizeOfControlMessageRight = 4 diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go index ccf1b9f72..e0f278b5c 100644 --- a/pkg/lisafs/client.go +++ b/pkg/lisafs/client.go @@ -313,7 +313,7 @@ func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error { // implicit conversion to an interface leads to an allocation. // // Precondition: reqMarshal and respUnmarshal must be non-nil. -func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error { +func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte) []byte, respUnmarshal func(src []byte) []byte, respFDs []int) error { if !c.IsSupported(m) { return unix.EOPNOTSUPP } diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go index 722afd0be..0d7c30ce3 100644 --- a/pkg/lisafs/message.go +++ b/pkg/lisafs/message.go @@ -157,7 +157,6 @@ func MaxMessageSize() uint32 { // TODO(gvisor.dev/issue/6450): Once this is resolved: // * Update manual implementations and function signatures. // * Update RPC handlers and appropriate callers to handle errors correctly. -// * Update manual implementations to get rid of buffer shifting. // UID represents a user ID. // @@ -180,10 +179,10 @@ func (gid GID) Ok() bool { } // NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes. -func NoopMarshal([]byte) {} +func NoopMarshal(b []byte) []byte { return b } // NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes. -func NoopUnmarshal([]byte) {} +func NoopUnmarshal(b []byte) []byte { return b } // SizedString represents a string in memory. The marshalled string bytes are // preceded by a uint32 signifying the string length. @@ -195,21 +194,20 @@ func (s *SizedString) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *SizedString) MarshalBytes(dst []byte) { +func (s *SizedString) MarshalBytes(dst []byte) []byte { strLen := primitive.Uint32(len(*s)) - strLen.MarshalUnsafe(dst) - dst = dst[strLen.SizeBytes():] + dst = strLen.MarshalUnsafe(dst) // Copy without any allocation. - copy(dst[:strLen], *s) + return dst[copy(dst[:strLen], *s):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *SizedString) UnmarshalBytes(src []byte) { +func (s *SizedString) UnmarshalBytes(src []byte) []byte { var strLen primitive.Uint32 - strLen.UnmarshalUnsafe(src) - src = src[strLen.SizeBytes():] + src = strLen.UnmarshalUnsafe(src) // Take the hit, this leads to an allocation + memcpy. No way around it. *s = SizedString(src[:strLen]) + return src[strLen:] } // StringArray represents an array of SizedStrings in memory. The marshalled @@ -227,22 +225,20 @@ func (s *StringArray) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *StringArray) MarshalBytes(dst []byte) { +func (s *StringArray) MarshalBytes(dst []byte) []byte { arrLen := primitive.Uint32(len(*s)) - arrLen.MarshalUnsafe(dst) - dst = dst[arrLen.SizeBytes():] + dst = arrLen.MarshalUnsafe(dst) for _, str := range *s { sstr := SizedString(str) - sstr.MarshalBytes(dst) - dst = dst[sstr.SizeBytes():] + dst = sstr.MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *StringArray) UnmarshalBytes(src []byte) { +func (s *StringArray) UnmarshalBytes(src []byte) []byte { var arrLen primitive.Uint32 - arrLen.UnmarshalUnsafe(src) - src = src[arrLen.SizeBytes():] + src = arrLen.UnmarshalUnsafe(src) if cap(*s) < int(arrLen) { *s = make([]string, arrLen) @@ -252,10 +248,10 @@ func (s *StringArray) UnmarshalBytes(src []byte) { for i := primitive.Uint32(0); i < arrLen; i++ { var sstr SizedString - sstr.UnmarshalBytes(src) - src = src[sstr.SizeBytes():] + src = sstr.UnmarshalBytes(src) (*s)[i] = string(sstr) } + return src } // Inode represents an inode on the remote filesystem. @@ -278,13 +274,13 @@ func (m *MountReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MountReq) MarshalBytes(dst []byte) { - m.MountPath.MarshalBytes(dst) +func (m *MountReq) MarshalBytes(dst []byte) []byte { + return m.MountPath.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MountReq) UnmarshalBytes(src []byte) { - m.MountPath.UnmarshalBytes(src) +func (m *MountReq) UnmarshalBytes(src []byte) []byte { + return m.MountPath.UnmarshalBytes(src) } // MountResp represents a Mount response. @@ -306,28 +302,30 @@ func (m *MountResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MountResp) MarshalBytes(dst []byte) { - m.Root.MarshalUnsafe(dst) - dst = dst[m.Root.SizeBytes():] - m.MaxMessageSize.MarshalUnsafe(dst) - dst = dst[m.MaxMessageSize.SizeBytes():] +func (m *MountResp) MarshalBytes(dst []byte) []byte { + dst = m.Root.MarshalUnsafe(dst) + dst = m.MaxMessageSize.MarshalUnsafe(dst) numSupported := primitive.Uint16(len(m.SupportedMs)) - numSupported.MarshalBytes(dst) - dst = dst[numSupported.SizeBytes():] - MarshalUnsafeMIDSlice(m.SupportedMs, dst) + dst = numSupported.MarshalBytes(dst) + n, err := MarshalUnsafeMIDSlice(m.SupportedMs, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MountResp) UnmarshalBytes(src []byte) { - m.Root.UnmarshalUnsafe(src) - src = src[m.Root.SizeBytes():] - m.MaxMessageSize.UnmarshalUnsafe(src) - src = src[m.MaxMessageSize.SizeBytes():] +func (m *MountResp) UnmarshalBytes(src []byte) []byte { + src = m.Root.UnmarshalUnsafe(src) + src = m.MaxMessageSize.UnmarshalUnsafe(src) var numSupported primitive.Uint16 - numSupported.UnmarshalBytes(src) - src = src[numSupported.SizeBytes():] + src = numSupported.UnmarshalBytes(src) m.SupportedMs = make([]MID, numSupported) - UnmarshalUnsafeMIDSlice(m.SupportedMs, src) + n, err := UnmarshalUnsafeMIDSlice(m.SupportedMs, src) + if err != nil { + panic(err) + } + return src[n:] } // ChannelResp is the response to the create channel request. @@ -391,17 +389,15 @@ func (w *WalkReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *WalkReq) MarshalBytes(dst []byte) { - w.DirFD.MarshalUnsafe(dst) - dst = dst[w.DirFD.SizeBytes():] - w.Path.MarshalBytes(dst) +func (w *WalkReq) MarshalBytes(dst []byte) []byte { + dst = w.DirFD.MarshalUnsafe(dst) + return w.Path.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *WalkReq) UnmarshalBytes(src []byte) { - w.DirFD.UnmarshalUnsafe(src) - src = src[w.DirFD.SizeBytes():] - w.Path.UnmarshalBytes(src) +func (w *WalkReq) UnmarshalBytes(src []byte) []byte { + src = w.DirFD.UnmarshalUnsafe(src) + return w.Path.UnmarshalBytes(src) } // WalkStatus is used to indicate the reason for partial/unsuccessful server @@ -438,32 +434,36 @@ func (w *WalkResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *WalkResp) MarshalBytes(dst []byte) { - w.Status.MarshalUnsafe(dst) - dst = dst[w.Status.SizeBytes():] +func (w *WalkResp) MarshalBytes(dst []byte) []byte { + dst = w.Status.MarshalUnsafe(dst) numInodes := primitive.Uint32(len(w.Inodes)) - numInodes.MarshalUnsafe(dst) - dst = dst[numInodes.SizeBytes():] + dst = numInodes.MarshalUnsafe(dst) - MarshalUnsafeInodeSlice(w.Inodes, dst) + n, err := MarshalUnsafeInodeSlice(w.Inodes, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *WalkResp) UnmarshalBytes(src []byte) { - w.Status.UnmarshalUnsafe(src) - src = src[w.Status.SizeBytes():] +func (w *WalkResp) UnmarshalBytes(src []byte) []byte { + src = w.Status.UnmarshalUnsafe(src) var numInodes primitive.Uint32 - numInodes.UnmarshalUnsafe(src) - src = src[numInodes.SizeBytes():] + src = numInodes.UnmarshalUnsafe(src) if cap(w.Inodes) < int(numInodes) { w.Inodes = make([]Inode, numInodes) } else { w.Inodes = w.Inodes[:numInodes] } - UnmarshalUnsafeInodeSlice(w.Inodes, src) + n, err := UnmarshalUnsafeInodeSlice(w.Inodes, src) + if err != nil { + panic(err) + } + return src[n:] } // WalkStatResp is used to communicate stat results for WalkStat. @@ -477,26 +477,32 @@ func (w *WalkStatResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *WalkStatResp) MarshalBytes(dst []byte) { +func (w *WalkStatResp) MarshalBytes(dst []byte) []byte { numStats := primitive.Uint32(len(w.Stats)) - numStats.MarshalUnsafe(dst) - dst = dst[numStats.SizeBytes():] + dst = numStats.MarshalUnsafe(dst) - linux.MarshalUnsafeStatxSlice(w.Stats, dst) + n, err := linux.MarshalUnsafeStatxSlice(w.Stats, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *WalkStatResp) UnmarshalBytes(src []byte) { +func (w *WalkStatResp) UnmarshalBytes(src []byte) []byte { var numStats primitive.Uint32 - numStats.UnmarshalUnsafe(src) - src = src[numStats.SizeBytes():] + src = numStats.UnmarshalUnsafe(src) if cap(w.Stats) < int(numStats) { w.Stats = make([]linux.Statx, numStats) } else { w.Stats = w.Stats[:numStats] } - linux.UnmarshalUnsafeStatxSlice(w.Stats, src) + n, err := linux.UnmarshalUnsafeStatxSlice(w.Stats, src) + if err != nil { + panic(err) + } + return src[n:] } // OpenAtReq is used to open existing FDs with the specified flags. @@ -536,21 +542,17 @@ func (o *OpenCreateAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (o *OpenCreateAtReq) MarshalBytes(dst []byte) { - o.createCommon.MarshalUnsafe(dst) - dst = dst[o.createCommon.SizeBytes():] - o.Name.MarshalBytes(dst) - dst = dst[o.Name.SizeBytes():] - o.Flags.MarshalUnsafe(dst) +func (o *OpenCreateAtReq) MarshalBytes(dst []byte) []byte { + dst = o.createCommon.MarshalUnsafe(dst) + dst = o.Name.MarshalBytes(dst) + return o.Flags.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) { - o.createCommon.UnmarshalUnsafe(src) - src = src[o.createCommon.SizeBytes():] - o.Name.UnmarshalBytes(src) - src = src[o.Name.SizeBytes():] - o.Flags.UnmarshalUnsafe(src) +func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) []byte { + src = o.createCommon.UnmarshalUnsafe(src) + src = o.Name.UnmarshalBytes(src) + return o.Flags.UnmarshalUnsafe(src) } // OpenCreateAtResp is used to communicate successful OpenCreateAt results. @@ -573,24 +575,30 @@ func (f *FdArray) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (f *FdArray) MarshalBytes(dst []byte) { +func (f *FdArray) MarshalBytes(dst []byte) []byte { arrLen := primitive.Uint32(len(*f)) - arrLen.MarshalUnsafe(dst) - dst = dst[arrLen.SizeBytes():] - MarshalUnsafeFDIDSlice(*f, dst) + dst = arrLen.MarshalUnsafe(dst) + n, err := MarshalUnsafeFDIDSlice(*f, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (f *FdArray) UnmarshalBytes(src []byte) { +func (f *FdArray) UnmarshalBytes(src []byte) []byte { var arrLen primitive.Uint32 - arrLen.UnmarshalUnsafe(src) - src = src[arrLen.SizeBytes():] + src = arrLen.UnmarshalUnsafe(src) if cap(*f) < int(arrLen) { *f = make(FdArray, arrLen) } else { *f = (*f)[:arrLen] } - UnmarshalUnsafeFDIDSlice(*f, src) + n, err := UnmarshalUnsafeFDIDSlice(*f, src) + if err != nil { + panic(err) + } + return src[n:] } // CloseReq is used to close(2) FDs. @@ -604,13 +612,13 @@ func (c *CloseReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (c *CloseReq) MarshalBytes(dst []byte) { - c.FDs.MarshalBytes(dst) +func (c *CloseReq) MarshalBytes(dst []byte) []byte { + return c.FDs.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (c *CloseReq) UnmarshalBytes(src []byte) { - c.FDs.UnmarshalBytes(src) +func (c *CloseReq) UnmarshalBytes(src []byte) []byte { + return c.FDs.UnmarshalBytes(src) } // FsyncReq is used to fsync(2) FDs. @@ -624,13 +632,13 @@ func (f *FsyncReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (f *FsyncReq) MarshalBytes(dst []byte) { - f.FDs.MarshalBytes(dst) +func (f *FsyncReq) MarshalBytes(dst []byte) []byte { + return f.FDs.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (f *FsyncReq) UnmarshalBytes(src []byte) { - f.FDs.UnmarshalBytes(src) +func (f *FsyncReq) UnmarshalBytes(src []byte) []byte { + return f.FDs.UnmarshalBytes(src) } // PReadReq is used to pread(2) on an FD. @@ -654,20 +662,18 @@ func (r *PReadResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *PReadResp) MarshalBytes(dst []byte) { - r.NumBytes.MarshalUnsafe(dst) - dst = dst[r.NumBytes.SizeBytes():] - copy(dst[:r.NumBytes], r.Buf[:r.NumBytes]) +func (r *PReadResp) MarshalBytes(dst []byte) []byte { + dst = r.NumBytes.MarshalUnsafe(dst) + return dst[copy(dst[:r.NumBytes], r.Buf[:r.NumBytes]):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *PReadResp) UnmarshalBytes(src []byte) { - r.NumBytes.UnmarshalUnsafe(src) - src = src[r.NumBytes.SizeBytes():] +func (r *PReadResp) UnmarshalBytes(src []byte) []byte { + src = r.NumBytes.UnmarshalUnsafe(src) // We expect the client to have already allocated r.Buf. r.Buf probably // (optimally) points to usermem. Directly copy into that. - copy(r.Buf[:r.NumBytes], src[:r.NumBytes]) + return src[copy(r.Buf[:r.NumBytes], src[:r.NumBytes]):] } // PWriteReq is used to pwrite(2) on an FD. @@ -684,28 +690,23 @@ func (w *PWriteReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *PWriteReq) MarshalBytes(dst []byte) { - w.Offset.MarshalUnsafe(dst) - dst = dst[w.Offset.SizeBytes():] - w.FD.MarshalUnsafe(dst) - dst = dst[w.FD.SizeBytes():] - w.NumBytes.MarshalUnsafe(dst) - dst = dst[w.NumBytes.SizeBytes():] - copy(dst[:w.NumBytes], w.Buf[:w.NumBytes]) +func (w *PWriteReq) MarshalBytes(dst []byte) []byte { + dst = w.Offset.MarshalUnsafe(dst) + dst = w.FD.MarshalUnsafe(dst) + dst = w.NumBytes.MarshalUnsafe(dst) + return dst[copy(dst[:w.NumBytes], w.Buf[:w.NumBytes]):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *PWriteReq) UnmarshalBytes(src []byte) { - w.Offset.UnmarshalUnsafe(src) - src = src[w.Offset.SizeBytes():] - w.FD.UnmarshalUnsafe(src) - src = src[w.FD.SizeBytes():] - w.NumBytes.UnmarshalUnsafe(src) - src = src[w.NumBytes.SizeBytes():] +func (w *PWriteReq) UnmarshalBytes(src []byte) []byte { + src = w.Offset.UnmarshalUnsafe(src) + src = w.FD.UnmarshalUnsafe(src) + src = w.NumBytes.UnmarshalUnsafe(src) // This is an optimization. Assuming that the server is making this call, it // is safe to just point to src rather than allocating and copying. w.Buf = src[:w.NumBytes] + return src[w.NumBytes:] } // PWriteResp is used to return the result of pwrite(2). @@ -727,17 +728,15 @@ func (m *MkdirAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MkdirAtReq) MarshalBytes(dst []byte) { - m.createCommon.MarshalUnsafe(dst) - dst = dst[m.createCommon.SizeBytes():] - m.Name.MarshalBytes(dst) +func (m *MkdirAtReq) MarshalBytes(dst []byte) []byte { + dst = m.createCommon.MarshalUnsafe(dst) + return m.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MkdirAtReq) UnmarshalBytes(src []byte) { - m.createCommon.UnmarshalUnsafe(src) - src = src[m.createCommon.SizeBytes():] - m.Name.UnmarshalBytes(src) +func (m *MkdirAtReq) UnmarshalBytes(src []byte) []byte { + src = m.createCommon.UnmarshalUnsafe(src) + return m.Name.UnmarshalBytes(src) } // MkdirAtResp is the response to a successful MkdirAt request. @@ -761,25 +760,19 @@ func (m *MknodAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MknodAtReq) MarshalBytes(dst []byte) { - m.createCommon.MarshalUnsafe(dst) - dst = dst[m.createCommon.SizeBytes():] - m.Name.MarshalBytes(dst) - dst = dst[m.Name.SizeBytes():] - m.Minor.MarshalUnsafe(dst) - dst = dst[m.Minor.SizeBytes():] - m.Major.MarshalUnsafe(dst) +func (m *MknodAtReq) MarshalBytes(dst []byte) []byte { + dst = m.createCommon.MarshalUnsafe(dst) + dst = m.Name.MarshalBytes(dst) + dst = m.Minor.MarshalUnsafe(dst) + return m.Major.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MknodAtReq) UnmarshalBytes(src []byte) { - m.createCommon.UnmarshalUnsafe(src) - src = src[m.createCommon.SizeBytes():] - m.Name.UnmarshalBytes(src) - src = src[m.Name.SizeBytes():] - m.Minor.UnmarshalUnsafe(src) - src = src[m.Minor.SizeBytes():] - m.Major.UnmarshalUnsafe(src) +func (m *MknodAtReq) UnmarshalBytes(src []byte) []byte { + src = m.createCommon.UnmarshalUnsafe(src) + src = m.Name.UnmarshalBytes(src) + src = m.Minor.UnmarshalUnsafe(src) + return m.Major.UnmarshalUnsafe(src) } // MknodAtResp is the response to a successful MknodAt request. @@ -804,29 +797,21 @@ func (s *SymlinkAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *SymlinkAtReq) MarshalBytes(dst []byte) { - s.DirFD.MarshalUnsafe(dst) - dst = dst[s.DirFD.SizeBytes():] - s.Name.MarshalBytes(dst) - dst = dst[s.Name.SizeBytes():] - s.Target.MarshalBytes(dst) - dst = dst[s.Target.SizeBytes():] - s.UID.MarshalUnsafe(dst) - dst = dst[s.UID.SizeBytes():] - s.GID.MarshalUnsafe(dst) +func (s *SymlinkAtReq) MarshalBytes(dst []byte) []byte { + dst = s.DirFD.MarshalUnsafe(dst) + dst = s.Name.MarshalBytes(dst) + dst = s.Target.MarshalBytes(dst) + dst = s.UID.MarshalUnsafe(dst) + return s.GID.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *SymlinkAtReq) UnmarshalBytes(src []byte) { - s.DirFD.UnmarshalUnsafe(src) - src = src[s.DirFD.SizeBytes():] - s.Name.UnmarshalBytes(src) - src = src[s.Name.SizeBytes():] - s.Target.UnmarshalBytes(src) - src = src[s.Target.SizeBytes():] - s.UID.UnmarshalUnsafe(src) - src = src[s.UID.SizeBytes():] - s.GID.UnmarshalUnsafe(src) +func (s *SymlinkAtReq) UnmarshalBytes(src []byte) []byte { + src = s.DirFD.UnmarshalUnsafe(src) + src = s.Name.UnmarshalBytes(src) + src = s.Target.UnmarshalBytes(src) + src = s.UID.UnmarshalUnsafe(src) + return s.GID.UnmarshalUnsafe(src) } // SymlinkAtResp is the response to a successful SymlinkAt request. @@ -849,21 +834,17 @@ func (l *LinkAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (l *LinkAtReq) MarshalBytes(dst []byte) { - l.DirFD.MarshalUnsafe(dst) - dst = dst[l.DirFD.SizeBytes():] - l.Target.MarshalUnsafe(dst) - dst = dst[l.Target.SizeBytes():] - l.Name.MarshalBytes(dst) +func (l *LinkAtReq) MarshalBytes(dst []byte) []byte { + dst = l.DirFD.MarshalUnsafe(dst) + dst = l.Target.MarshalUnsafe(dst) + return l.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (l *LinkAtReq) UnmarshalBytes(src []byte) { - l.DirFD.UnmarshalUnsafe(src) - src = src[l.DirFD.SizeBytes():] - l.Target.UnmarshalUnsafe(src) - src = src[l.Target.SizeBytes():] - l.Name.UnmarshalBytes(src) +func (l *LinkAtReq) UnmarshalBytes(src []byte) []byte { + src = l.DirFD.UnmarshalUnsafe(src) + src = l.Target.UnmarshalUnsafe(src) + return l.Name.UnmarshalBytes(src) } // LinkAtResp is used to respond to a successful LinkAt request. @@ -923,13 +904,13 @@ func (r *ReadLinkAtResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *ReadLinkAtResp) MarshalBytes(dst []byte) { - r.Target.MarshalBytes(dst) +func (r *ReadLinkAtResp) MarshalBytes(dst []byte) []byte { + return r.Target.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) { - r.Target.UnmarshalBytes(src) +func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) []byte { + return r.Target.UnmarshalBytes(src) } // FlushReq is used to make Flush requests. @@ -963,21 +944,17 @@ func (u *UnlinkAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (u *UnlinkAtReq) MarshalBytes(dst []byte) { - u.DirFD.MarshalUnsafe(dst) - dst = dst[u.DirFD.SizeBytes():] - u.Name.MarshalBytes(dst) - dst = dst[u.Name.SizeBytes():] - u.Flags.MarshalUnsafe(dst) +func (u *UnlinkAtReq) MarshalBytes(dst []byte) []byte { + dst = u.DirFD.MarshalUnsafe(dst) + dst = u.Name.MarshalBytes(dst) + return u.Flags.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (u *UnlinkAtReq) UnmarshalBytes(src []byte) { - u.DirFD.UnmarshalUnsafe(src) - src = src[u.DirFD.SizeBytes():] - u.Name.UnmarshalBytes(src) - src = src[u.Name.SizeBytes():] - u.Flags.UnmarshalUnsafe(src) +func (u *UnlinkAtReq) UnmarshalBytes(src []byte) []byte { + src = u.DirFD.UnmarshalUnsafe(src) + src = u.Name.UnmarshalBytes(src) + return u.Flags.UnmarshalUnsafe(src) } // RenameAtReq is used to make Rename requests. Note that the request takes in @@ -994,21 +971,17 @@ func (r *RenameAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *RenameAtReq) MarshalBytes(dst []byte) { - r.Renamed.MarshalUnsafe(dst) - dst = dst[r.Renamed.SizeBytes():] - r.NewDir.MarshalUnsafe(dst) - dst = dst[r.NewDir.SizeBytes():] - r.NewName.MarshalBytes(dst) +func (r *RenameAtReq) MarshalBytes(dst []byte) []byte { + dst = r.Renamed.MarshalUnsafe(dst) + dst = r.NewDir.MarshalUnsafe(dst) + return r.NewName.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *RenameAtReq) UnmarshalBytes(src []byte) { - r.Renamed.UnmarshalUnsafe(src) - src = src[r.Renamed.SizeBytes():] - r.NewDir.UnmarshalUnsafe(src) - src = src[r.NewDir.SizeBytes():] - r.NewName.UnmarshalBytes(src) +func (r *RenameAtReq) UnmarshalBytes(src []byte) []byte { + src = r.Renamed.UnmarshalUnsafe(src) + src = r.NewDir.UnmarshalUnsafe(src) + return r.NewName.UnmarshalBytes(src) } // Getdents64Req is used to make Getdents64 requests. @@ -1039,33 +1012,23 @@ func (d *Dirent64) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (d *Dirent64) MarshalBytes(dst []byte) { - d.Ino.MarshalUnsafe(dst) - dst = dst[d.Ino.SizeBytes():] - d.DevMinor.MarshalUnsafe(dst) - dst = dst[d.DevMinor.SizeBytes():] - d.DevMajor.MarshalUnsafe(dst) - dst = dst[d.DevMajor.SizeBytes():] - d.Off.MarshalUnsafe(dst) - dst = dst[d.Off.SizeBytes():] - d.Type.MarshalUnsafe(dst) - dst = dst[d.Type.SizeBytes():] - d.Name.MarshalBytes(dst) +func (d *Dirent64) MarshalBytes(dst []byte) []byte { + dst = d.Ino.MarshalUnsafe(dst) + dst = d.DevMinor.MarshalUnsafe(dst) + dst = d.DevMajor.MarshalUnsafe(dst) + dst = d.Off.MarshalUnsafe(dst) + dst = d.Type.MarshalUnsafe(dst) + return d.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (d *Dirent64) UnmarshalBytes(src []byte) { - d.Ino.UnmarshalUnsafe(src) - src = src[d.Ino.SizeBytes():] - d.DevMinor.UnmarshalUnsafe(src) - src = src[d.DevMinor.SizeBytes():] - d.DevMajor.UnmarshalUnsafe(src) - src = src[d.DevMajor.SizeBytes():] - d.Off.UnmarshalUnsafe(src) - src = src[d.Off.SizeBytes():] - d.Type.UnmarshalUnsafe(src) - src = src[d.Type.SizeBytes():] - d.Name.UnmarshalBytes(src) +func (d *Dirent64) UnmarshalBytes(src []byte) []byte { + src = d.Ino.UnmarshalUnsafe(src) + src = d.DevMinor.UnmarshalUnsafe(src) + src = d.DevMajor.UnmarshalUnsafe(src) + src = d.Off.UnmarshalUnsafe(src) + src = d.Type.UnmarshalUnsafe(src) + return d.Name.UnmarshalBytes(src) } // Getdents64Resp is used to communicate getdents64 results. @@ -1083,31 +1046,29 @@ func (g *Getdents64Resp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (g *Getdents64Resp) MarshalBytes(dst []byte) { +func (g *Getdents64Resp) MarshalBytes(dst []byte) []byte { numDirents := primitive.Uint32(len(g.Dirents)) - numDirents.MarshalUnsafe(dst) - dst = dst[numDirents.SizeBytes():] + dst = numDirents.MarshalUnsafe(dst) for i := range g.Dirents { - g.Dirents[i].MarshalBytes(dst) - dst = dst[g.Dirents[i].SizeBytes():] + dst = g.Dirents[i].MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (g *Getdents64Resp) UnmarshalBytes(src []byte) { +func (g *Getdents64Resp) UnmarshalBytes(src []byte) []byte { var numDirents primitive.Uint32 - numDirents.UnmarshalUnsafe(src) + src = numDirents.UnmarshalUnsafe(src) if cap(g.Dirents) < int(numDirents) { g.Dirents = make([]Dirent64, numDirents) } else { g.Dirents = g.Dirents[:numDirents] } - src = src[numDirents.SizeBytes():] for i := range g.Dirents { - g.Dirents[i].UnmarshalBytes(src) - src = src[g.Dirents[i].SizeBytes():] + src = g.Dirents[i].UnmarshalBytes(src) } + return src } // FGetXattrReq is used to make FGetXattr requests. The response to this is @@ -1124,21 +1085,17 @@ func (g *FGetXattrReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (g *FGetXattrReq) MarshalBytes(dst []byte) { - g.FD.MarshalUnsafe(dst) - dst = dst[g.FD.SizeBytes():] - g.BufSize.MarshalUnsafe(dst) - dst = dst[g.BufSize.SizeBytes():] - g.Name.MarshalBytes(dst) +func (g *FGetXattrReq) MarshalBytes(dst []byte) []byte { + dst = g.FD.MarshalUnsafe(dst) + dst = g.BufSize.MarshalUnsafe(dst) + return g.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (g *FGetXattrReq) UnmarshalBytes(src []byte) { - g.FD.UnmarshalUnsafe(src) - src = src[g.FD.SizeBytes():] - g.BufSize.UnmarshalUnsafe(src) - src = src[g.BufSize.SizeBytes():] - g.Name.UnmarshalBytes(src) +func (g *FGetXattrReq) UnmarshalBytes(src []byte) []byte { + src = g.FD.UnmarshalUnsafe(src) + src = g.BufSize.UnmarshalUnsafe(src) + return g.Name.UnmarshalBytes(src) } // FGetXattrResp is used to respond to FGetXattr request. @@ -1152,13 +1109,13 @@ func (g *FGetXattrResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (g *FGetXattrResp) MarshalBytes(dst []byte) { - g.Value.MarshalBytes(dst) +func (g *FGetXattrResp) MarshalBytes(dst []byte) []byte { + return g.Value.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (g *FGetXattrResp) UnmarshalBytes(src []byte) { - g.Value.UnmarshalBytes(src) +func (g *FGetXattrResp) UnmarshalBytes(src []byte) []byte { + return g.Value.UnmarshalBytes(src) } // FSetXattrReq is used to make FSetXattr requests. It has no response. @@ -1175,25 +1132,19 @@ func (s *FSetXattrReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *FSetXattrReq) MarshalBytes(dst []byte) { - s.FD.MarshalUnsafe(dst) - dst = dst[s.FD.SizeBytes():] - s.Flags.MarshalUnsafe(dst) - dst = dst[s.Flags.SizeBytes():] - s.Name.MarshalBytes(dst) - dst = dst[s.Name.SizeBytes():] - s.Value.MarshalBytes(dst) +func (s *FSetXattrReq) MarshalBytes(dst []byte) []byte { + dst = s.FD.MarshalUnsafe(dst) + dst = s.Flags.MarshalUnsafe(dst) + dst = s.Name.MarshalBytes(dst) + return s.Value.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *FSetXattrReq) UnmarshalBytes(src []byte) { - s.FD.UnmarshalUnsafe(src) - src = src[s.FD.SizeBytes():] - s.Flags.UnmarshalUnsafe(src) - src = src[s.Flags.SizeBytes():] - s.Name.UnmarshalBytes(src) - src = src[s.Name.SizeBytes():] - s.Value.UnmarshalBytes(src) +func (s *FSetXattrReq) UnmarshalBytes(src []byte) []byte { + src = s.FD.UnmarshalUnsafe(src) + src = s.Flags.UnmarshalUnsafe(src) + src = s.Name.UnmarshalBytes(src) + return s.Value.UnmarshalBytes(src) } // FRemoveXattrReq is used to make FRemoveXattr requests. It has no response. @@ -1208,17 +1159,15 @@ func (r *FRemoveXattrReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *FRemoveXattrReq) MarshalBytes(dst []byte) { - r.FD.MarshalUnsafe(dst) - dst = dst[r.FD.SizeBytes():] - r.Name.MarshalBytes(dst) +func (r *FRemoveXattrReq) MarshalBytes(dst []byte) []byte { + dst = r.FD.MarshalUnsafe(dst) + return r.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) { - r.FD.UnmarshalUnsafe(src) - src = src[r.FD.SizeBytes():] - r.Name.UnmarshalBytes(src) +func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) []byte { + src = r.FD.UnmarshalUnsafe(src) + return r.Name.UnmarshalBytes(src) } // FListXattrReq is used to make FListXattr requests. @@ -1241,11 +1190,11 @@ func (l *FListXattrResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (l *FListXattrResp) MarshalBytes(dst []byte) { - l.Xattrs.MarshalBytes(dst) +func (l *FListXattrResp) MarshalBytes(dst []byte) []byte { + return l.Xattrs.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (l *FListXattrResp) UnmarshalBytes(src []byte) { - l.Xattrs.UnmarshalBytes(src) +func (l *FListXattrResp) UnmarshalBytes(src []byte) []byte { + return l.Xattrs.UnmarshalBytes(src) } diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go index 3868dfa08..745736b6d 100644 --- a/pkg/lisafs/sample_message.go +++ b/pkg/lisafs/sample_message.go @@ -53,18 +53,24 @@ func (m *MsgDynamic) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MsgDynamic) MarshalBytes(dst []byte) { - m.N.MarshalUnsafe(dst) - dst = dst[m.N.SizeBytes():] - MarshalUnsafeMsg1Slice(m.Arr, dst) +func (m *MsgDynamic) MarshalBytes(dst []byte) []byte { + dst = m.N.MarshalUnsafe(dst) + n, err := MarshalUnsafeMsg1Slice(m.Arr, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MsgDynamic) UnmarshalBytes(src []byte) { - m.N.UnmarshalUnsafe(src) - src = src[m.N.SizeBytes():] +func (m *MsgDynamic) UnmarshalBytes(src []byte) []byte { + src = m.N.UnmarshalUnsafe(src) m.Arr = make([]MsgSimple, m.N) - UnmarshalUnsafeMsg1Slice(m.Arr, src) + n, err := UnmarshalUnsafeMsg1Slice(m.Arr, src) + if err != nil { + panic(err) + } + return src[n:] } // Randomize randomizes the contents of m. @@ -90,21 +96,18 @@ func (v *P9Version) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (v *P9Version) MarshalBytes(dst []byte) { - v.MSize.MarshalUnsafe(dst) - dst = dst[v.MSize.SizeBytes():] +func (v *P9Version) MarshalBytes(dst []byte) []byte { + dst = v.MSize.MarshalUnsafe(dst) versionLen := primitive.Uint16(len(v.Version)) - versionLen.MarshalUnsafe(dst) - dst = dst[versionLen.SizeBytes():] - copy(dst, v.Version) + dst = versionLen.MarshalUnsafe(dst) + return dst[copy(dst, v.Version):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (v *P9Version) UnmarshalBytes(src []byte) { - v.MSize.UnmarshalUnsafe(src) - src = src[v.MSize.SizeBytes():] +func (v *P9Version) UnmarshalBytes(src []byte) []byte { + src = v.MSize.UnmarshalUnsafe(src) var versionLen primitive.Uint16 - versionLen.UnmarshalUnsafe(src) - src = src[versionLen.SizeBytes():] + src = versionLen.UnmarshalUnsafe(src) v.Version = string(src[:versionLen]) + return src[versionLen:] } diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go index 7da450ce8..9e34eae80 100644 --- a/pkg/marshal/marshal.go +++ b/pkg/marshal/marshal.go @@ -59,13 +59,15 @@ type Marshallable interface { // likely make use of the type of these fields). SizeBytes() int - // MarshalBytes serializes a copy of a type to dst. + // MarshalBytes serializes a copy of a type to dst and returns the remaining + // buffer. // Precondition: dst must be at least SizeBytes() in length. - MarshalBytes(dst []byte) + MarshalBytes(dst []byte) []byte - // UnmarshalBytes deserializes a type from src. + // UnmarshalBytes deserializes a type from src and returns the remaining + // buffer. // Precondition: src must be at least SizeBytes() in length. - UnmarshalBytes(src []byte) + UnmarshalBytes(src []byte) []byte // Packed returns true if the marshalled size of the type is the same as the // size it occupies in memory. This happens when the type has no fields @@ -86,7 +88,7 @@ type Marshallable interface { // return false, MarshalUnsafe should fall back to the safer but slower // MarshalBytes. // Precondition: dst must be at least SizeBytes() in length. - MarshalUnsafe(dst []byte) + MarshalUnsafe(dst []byte) []byte // UnmarshalUnsafe deserializes a type by directly copying to the underlying // memory allocated for the object by the runtime. @@ -96,7 +98,7 @@ type Marshallable interface { // UnmarshalUnsafe should fall back to the safer but slower unmarshal // mechanism implemented in UnmarshalBytes. // Precondition: src must be at least SizeBytes() in length. - UnmarshalUnsafe(src []byte) + UnmarshalUnsafe(src []byte) []byte // CopyIn deserializes a Marshallable type from a task's memory. This may // only be called from a task goroutine. This is more efficient than calling diff --git a/pkg/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go index 9e6a6fa29..6c1cf7a4c 100644 --- a/pkg/marshal/marshal_impl_util.go +++ b/pkg/marshal/marshal_impl_util.go @@ -38,12 +38,12 @@ func (StubMarshallable) SizeBytes() int { } // MarshalBytes implements Marshallable.MarshalBytes. -func (StubMarshallable) MarshalBytes(dst []byte) { +func (StubMarshallable) MarshalBytes(dst []byte) []byte { panic("Please implement your own MarshalBytes function") } // UnmarshalBytes implements Marshallable.UnmarshalBytes. -func (StubMarshallable) UnmarshalBytes(src []byte) { +func (StubMarshallable) UnmarshalBytes(src []byte) []byte { panic("Please implement your own UnmarshalBytes function") } @@ -53,12 +53,12 @@ func (StubMarshallable) Packed() bool { } // MarshalUnsafe implements Marshallable.MarshalUnsafe. -func (StubMarshallable) MarshalUnsafe(dst []byte) { +func (StubMarshallable) MarshalUnsafe(dst []byte) []byte { panic("Please implement your own MarshalUnsafe function") } // UnmarshalUnsafe implements Marshallable.UnmarshalUnsafe. -func (StubMarshallable) UnmarshalUnsafe(src []byte) { +func (StubMarshallable) UnmarshalUnsafe(src []byte) []byte { panic("Please implement your own UnmarshalUnsafe function") } diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 1c49cf082..7ece26933 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -76,13 +76,13 @@ func (b *ByteSlice) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (b *ByteSlice) MarshalBytes(dst []byte) { - copy(dst, *b) +func (b *ByteSlice) MarshalBytes(dst []byte) []byte { + return dst[copy(dst, *b):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (b *ByteSlice) UnmarshalBytes(src []byte) { - copy(*b, src) +func (b *ByteSlice) UnmarshalBytes(src []byte) []byte { + return src[copy(*b, src):] } // Packed implements marshal.Marshallable.Packed. @@ -91,13 +91,13 @@ func (b *ByteSlice) Packed() bool { } // MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (b *ByteSlice) MarshalUnsafe(dst []byte) { - b.MarshalBytes(dst) +func (b *ByteSlice) MarshalUnsafe(dst []byte) []byte { + return b.MarshalBytes(dst) } // UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (b *ByteSlice) UnmarshalUnsafe(src []byte) { - b.UnmarshalBytes(src) +func (b *ByteSlice) UnmarshalUnsafe(src []byte) []byte { + return b.UnmarshalBytes(src) } // CopyIn implements marshal.Marshallable.CopyIn. diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 05c4fbeb2..18497a880 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -77,8 +77,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 1fddd858e..d98d2832b 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) @@ -69,14 +70,10 @@ func TestConnectionAbort(t *testing.T) { t.Fatalf("newTestConnection: %v", err) } - testObj := &testPayload{ - data: rand.Uint32(), - } - var futNormal []*futureResponse - + testObj := primitive.Uint32(rand.Uint32()) for i := 0; i < int(numRequests); i++ { - req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) + req := conn.NewRequest(creds, uint32(i), uint64(i), 0, &testObj) fut, err := conn.callFutureLocked(task, req) if err != nil { t.Fatalf("callFutureLocked failed: %v", err) @@ -102,7 +99,7 @@ func TestConnectionAbort(t *testing.T) { } // After abort, Call() should return directly with ENOTCONN. - req := conn.NewRequest(creds, 0, 0, 0, testObj) + req := conn.NewRequest(creds, 0, 0, 0, &testObj) _, err = conn.Call(task, req) if !linuxerr.Equals(linuxerr.ENOTCONN, err) { t.Fatalf("Incorrect error code received for Call() after connection aborted") diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 8951b5ba8..13b32fc7c 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -215,11 +216,9 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con if err != nil { t.Fatal(err) } - testObj := &testPayload{ - data: rand.Uint32(), - } - req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) + testObj := primitive.Uint32(rand.Uint32()) + req := conn.NewRequest(creds, pid, inode, echoTestOpcode, &testObj) // Queue up a request. // Analogous to Call except it doesn't block on the task. @@ -232,7 +231,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con t.Fatalf("Server responded with an error: %v", err) } - var respTestPayload testPayload + var respTestPayload primitive.Uint32 if err := resp.UnmarshalPayload(&respTestPayload); err != nil { t.Fatalf("Unmarshalling payload error: %v", err) } @@ -242,8 +241,8 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con req.hdr.Unique, resp.hdr.Unique) } - if respTestPayload.data != testObj.data { - t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data) + if respTestPayload != testObj { + t.Fatalf("read incorrect data. Data expected: %d, but got %d", testObj, respTestPayload) } } @@ -256,8 +255,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F // Create the tasks that the server will be using. tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - var readPayload testPayload + var readPayload primitive.Uint32 serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root) if err != nil { t.Fatal(err) @@ -291,8 +290,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F } var readFUSEHeaderIn linux.FUSEHeaderIn - readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen]) - readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen]) + inBuf = readFUSEHeaderIn.UnmarshalUnsafe(inBuf) + readPayload.UnmarshalUnsafe(inBuf) if readFUSEHeaderIn.Opcode != echoTestOpcode { t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload) diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index af16098d2..00b520c31 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -489,7 +489,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr // Lookup implements kernfs.Inode.Lookup. func (i *inode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) { - in := linux.FUSELookupIn{Name: name} + in := linux.FUSELookupIn{Name: linux.CString(name)} return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in) } @@ -520,7 +520,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) Mode: uint32(opts.Mode) | linux.S_IFREG, Umask: uint32(kernelTask.FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in) } @@ -533,7 +533,7 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor), Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in) } @@ -541,8 +541,8 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) // NewSymlink implements kernfs.Inode.NewSymlink. func (i *inode) NewSymlink(ctx context.Context, name, target string) (kernfs.Inode, error) { in := linux.FUSESymLinkIn{ - Name: name, - Target: target, + Name: linux.CString(name), + Target: linux.CString(target), } return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in) } @@ -554,7 +554,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) return linuxerr.EINVAL } - in := linux.FUSEUnlinkIn{Name: name} + in := linux.FUSEUnlinkIn{Name: linux.CString(name)} req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { @@ -571,7 +571,7 @@ func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) Mode: uint32(opts.Mode), Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in) } @@ -581,7 +581,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro fusefs := i.fs task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) - in := linux.FUSERmDirIn{Name: name} + in := linux.FUSERmDirIn{Name: linux.CString(name)} req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index 8a72489fa..ec76ec2a4 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -41,7 +41,7 @@ type fuseInitRes struct { } // UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes. -func (r *fuseInitRes) UnmarshalBytes(src []byte) { +func (r *fuseInitRes) UnmarshalBytes(src []byte) []byte { out := &r.initOut // Introduced before FUSE kernel version 7.13. @@ -70,7 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(hostarch.ByteOrder.Uint16(src[:2])) src = src[2:] } - _ = src // Remove unused warning. + return src } // SizeBytes is the size of the payload of the FUSE_INIT response. diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go index b0bab0066..8d4a2fad3 100644 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -15,17 +15,13 @@ package fuse import ( - "io" "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - - "gvisor.dev/gvisor/pkg/hostarch" ) func setup(t *testing.T) *testutil.System { @@ -70,58 +66,3 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque } return fs.conn, &fuseDev.vfsfd, nil } - -type testPayload struct { - marshal.StubMarshallable - data uint32 -} - -// SizeBytes implements marshal.Marshallable.SizeBytes. -func (t *testPayload) SizeBytes() int { - return 4 -} - -// MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *testPayload) MarshalBytes(dst []byte) { - hostarch.ByteOrder.PutUint32(dst[:4], t.data) -} - -// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *testPayload) UnmarshalBytes(src []byte) { - *t = testPayload{data: hostarch.ByteOrder.Uint32(src[:4])} -} - -// Packed implements marshal.Marshallable.Packed. -func (t *testPayload) Packed() bool { - return true -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (t *testPayload) MarshalUnsafe(dst []byte) { - t.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (t *testPayload) UnmarshalUnsafe(src []byte) { - t.UnmarshalBytes(src) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (t *testPayload) CopyOutN(task marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { - panic("not implemented") -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (t *testPayload) CopyOut(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (t *testPayload) CopyIn(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// WriteTo implements io.WriterTo.WriteTo. -func (t *testPayload) WriteTo(w io.Writer) (int64, error) { - panic("not implemented") -} diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index fb213d109..09b148164 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -212,8 +212,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { phdrs := make([]elf.ProgHeader, hdr.Phnum) for i := range phdrs { var prog64 linux.ElfProg64 - prog64.UnmarshalUnsafe(phdrBuf[:prog64Size]) - phdrBuf = phdrBuf[prog64Size:] + phdrBuf = prog64.UnmarshalUnsafe(phdrBuf) phdrs[i] = elf.ProgHeader{ Type: elf.ProgType(prog64.Type), Flags: elf.ProgFlag(prog64.Flags), diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 6077b2150..4b036b323 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -17,6 +17,8 @@ package control import ( + "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" @@ -29,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "time" ) // SCMCredentials represents a SCM_CREDENTIALS socket control message. @@ -63,10 +64,10 @@ type RightsFiles []*fs.File // NewSCMRights creates a new SCM_RIGHTS socket control message representation // using local sentry FDs. -func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { +func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) { files := make(RightsFiles, 0, len(fds)) for _, fd := range fds { - file := t.GetFile(fd) + file := t.GetFile(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF @@ -486,26 +487,25 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) { var ( cmsgs socket.ControlMessages - fds linux.ControlMessageRights + fds []primitive.Int32 ) - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { return cmsgs, linuxerr.EINVAL } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, linuxerr.EINVAL } - if h.Length > uint64(len(buf)-i) { - return socket.ControlMessages{}, linuxerr.EINVAL - } - i += linux.SizeOfControlMessageHeader length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { + return socket.ControlMessages{}, linuxerr.EINVAL + } switch h.Level { case linux.SOL_SOCKET: @@ -518,11 +518,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) - } - - i += bits.AlignUp(length, width) + curFDs := make([]primitive.Int32, numRights) + primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize]) + fds = append(fds, curFDs...) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -530,23 +528,21 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + creds.UnmarshalUnsafe(buf) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, linuxerr.EINVAL } var ts linux.Timeval - ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(buf) cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true - i += bits.AlignUp(length, width) default: // Unknown message type. @@ -560,9 +556,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + tos.UnmarshalUnsafe(buf) cmsgs.IP.TOS = uint8(tos) - i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -571,19 +566,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) - + packetInfo.UnmarshalUnsafe(buf) cmsgs.IP.PacketInfo = packetInfo - i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -591,9 +583,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -606,18 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + tclass.UnmarshalUnsafe(buf) cmsgs.IP.TClass = uint32(tclass) - i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -625,9 +614,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -635,6 +623,11 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) default: return socket.ControlMessages{}, linuxerr.EINVAL } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] + } } if cmsgs.Unix.Credentials == nil { diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index 0a989cbeb..a638cb955 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -45,10 +46,10 @@ type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message // representation using local sentry FDs. -func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { +func NewSCMRightsVFS2(t *kernel.Task, fds []primitive.Int32) (SCMRightsVFS2, error) { files := make(RightsFilesVFS2, 0, len(fds)) for _, fd := range fds { - file := t.GetFileVFS2(fd) + file := t.GetFileVFS2(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 6e2318f75..a31f3ebec 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -578,7 +578,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} - ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Timestamp = ts.ToTime() } @@ -587,18 +587,18 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IP_TOS: controlMessages.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + tos.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) + packetInfo.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -612,12 +612,12 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + tclass.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -631,7 +631,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.TCP_INQ: controlMessages.IP.HasInq = true var inq primitive.Int32 - inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + inq.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Inq = int32(inq) } } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 61111ac6c..c84ab3fb7 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -138,7 +138,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } var ifinfo linux.InterfaceInfoMessage - ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) + ifinfo.UnmarshalUnsafe(link.Data) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -169,7 +169,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } var ifaddr linux.InterfaceAddrMessage - ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) + ifaddr.UnmarshalUnsafe(addr.Data) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, PrefixLen: ifaddr.PrefixLen, @@ -201,7 +201,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) } var ifRoute linux.RouteMessage - ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) + ifRoute.UnmarshalUnsafe(routeMsg.Data) inetRoute := inet.Route{ Family: ifRoute.Family, DstLen: ifRoute.DstLen, diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 3f1b4a17b..7606d2bbb 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -80,9 +80,8 @@ func marshalEntryMatch(name string, data []byte) []byte { copy(matcher.Name[:], name) buf := make([]byte, size) - entryLen := matcher.XTEntryMatch.SizeBytes() - matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) - copy(buf[entryLen:], matcher.Data) + bufRemain := matcher.XTEntryMatch.MarshalUnsafe(buf) + copy(bufRemain, matcher.Data) return buf } diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index af31cbc5b..6cbfee8b6 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -141,10 +141,9 @@ func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IPTEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IPTEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 6cefe0b9c..902707abf 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -144,10 +144,9 @@ func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IP6TEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IP6TEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 01f2f8c77..f2cdf091d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -176,9 +176,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint // net/ipv4/netfilter/ip_tables.c:translate_table for reference. func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace - replaceBuf := optVal[:linux.SizeOfIPTReplace] - optVal = optVal[linux.SizeOfIPTReplace:] - replace.UnmarshalBytes(replaceBuf) + optVal = replace.UnmarshalBytes(optVal) var table stack.Table switch replace.Name.String() { @@ -306,8 +304,7 @@ func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:match.SizeBytes()] - match.UnmarshalUnsafe(buf) + match.UnmarshalUnsafe(optVal) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 6eff2ae65..d83d7e535 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -73,7 +73,7 @@ func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHe // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) + matchData.UnmarshalUnsafe(buf) nflog("parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index b9c15daab..eaf601543 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -199,7 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) + standardTarget.UnmarshalUnsafe(buf) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -253,7 +253,6 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return nil, syserr.ErrInvalidArgument } var errTgt linux.XTErrorTarget - buf = buf[:linux.SizeOfXTErrorTarget] errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: @@ -316,7 +315,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( } var rt linux.XTRedirectTarget - buf = buf[:linux.SizeOfXTRedirectTarget] rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. @@ -405,8 +403,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -477,7 +474,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var st linux.XTSNATTarget - buf = buf[:linux.SizeOfXTSNATTarget] st.UnmarshalUnsafe(buf) // Copy linux.XTSNATTarget to stack.SNATTarget. @@ -557,8 +553,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // TODO(gvisor.dev/issue/5697): Support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -621,7 +616,9 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) + // Do not advance optVal as targetMake.unmarshal() may unmarshal + // XTEntryTarget again but with some added fields. + target.UnmarshalUnsafe(optVal) return unmarshalTarget(target, filter, optVal) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index e5b73a976..a621a6a16 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -59,7 +59,7 @@ func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.XTTCP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index aa72ee70c..2ca854764 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -59,7 +59,7 @@ func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. var matchData linux.XTUDP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index 968968469..1604b2792 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -25,33 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) -type dummyNetlinkMsg struct { - marshal.StubMarshallable - Foo uint16 -} - -func (*dummyNetlinkMsg) SizeBytes() int { - return 2 -} - -func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { - p := primitive.Uint16(m.Foo) - p.MarshalUnsafe(dst) -} - -func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { - var p primitive.Uint16 - p.UnmarshalUnsafe(src) - m.Foo = uint16(p) -} - func TestParseMessage(t *testing.T) { + dummyNetlinkMsg := primitive.Uint16(0x3130) tests := []struct { desc string input []byte header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg + dataMsg marshal.Marshallable restLen int ok bool }{ @@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 1, ok: true, }, @@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) { t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) } - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) + var dataMsg primitive.Uint16 + _, dataOk := msg.GetData(&dataMsg) if !dataOk { t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) { t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 267155807..19c8f340d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) + sa.UnmarshalUnsafe(b) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c35cf06f6..e38c4e5da 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -660,7 +660,7 @@ func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) + a.UnmarshalBytes(sockaddr) addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), @@ -1839,7 +1839,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1852,7 +1852,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1883,7 +1883,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + v.UnmarshalBytes(optVal) if v != (linux.Linger{}) { socket.SetSockOptEmitUnimplementedEvent(t, name) @@ -2222,12 +2222,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) + req.UnmarshalUnsafe(optVal) return req, nil } var req linux.InetMulticastRequestWithNIC - req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) + req.InetMulticastRequest.UnmarshalUnsafe(optVal) return req, nil } @@ -2237,7 +2237,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) + req.UnmarshalUnsafe(optVal) return req, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fc5431eb1..01073df72 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -595,19 +595,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -738,7 +738,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInetSize]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -751,7 +751,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -767,7 +767,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) + a.UnmarshalUnsafe(addr) // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have // an ethernet address. if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f4aab25b0..b164d9107 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -161,12 +161,10 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } -func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { +func unmarshalControlMessageRights(src []byte) []primitive.Int32 { count := len(src) / linux.SizeOfControlMessageRight - cmr := make(linux.ControlMessageRights, count) - for i, _ := range cmr { - cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) - } + cmr := make([]primitive.Int32, count) + primitive.UnmarshalUnsafeInt32Slice(cmr, src) return cmr } @@ -182,14 +180,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) var strs []string - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { strs = append(strs, "{invalid control message (too short)}") break } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) var skipData bool level := "SOL_SOCKET" @@ -204,7 +202,9 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) typ = fmt.Sprint(h.Type) } - if h.Length > uint64(len(buf)-i) { + width := t.Arch().Width() + length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content extends beyond buffer}", level, @@ -214,9 +214,6 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) break } - i += linux.SizeOfControlMessageHeader - width := t.Arch().Width() - length := int(h.Length) - linux.SizeOfControlMessageHeader if length < 0 { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content too short}", @@ -229,78 +226,80 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += bits.AlignUp(length, width) - continue - } - - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) - fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) - rights := make([]string, 0, len(fds)) - for _, fd := range fds { - rights = append(rights, fmt.Sprint(fd)) - } + } else { + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[:rightsSize]) + rights := make([]string, 0, len(fds)) + for _, fd := range fds { + rights = append(rights, fmt.Sprint(fd)) + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content: %s}", - level, - typ, - h.Length, - strings.Join(rights, ","), - )) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, content: %s}", level, typ, h.Length, + strings.Join(rights, ","), )) - break - } - var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", - level, - typ, - h.Length, - creds.PID, - creds.UID, - creds.GID, - )) + var creds linux.ControlMessageCredentials + creds.UnmarshalUnsafe(buf) - case linux.SO_TIMESTAMP: - if length < linux.SizeOfTimeval { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", level, typ, h.Length, + creds.PID, + creds.UID, + creds.GID, )) - break - } - var tv linux.Timeval - tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", - level, - typ, - h.Length, - tv.Sec, - tv.Usec, - )) + var tv linux.Timeval + tv.UnmarshalUnsafe(buf) - default: - panic("unreachable") + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", + level, + typ, + h.Length, + tv.Sec, + tv.Usec, + )) + + default: + panic("unreachable") + } + } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] } - i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 3e643e77f..db135fd74 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -132,8 +132,7 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) g.shift(bufVar, 8) default: - g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) + g.emit("%s = %s.MarshalBytes(%s)\n", bufVar, accessor, bufVar) } } @@ -159,8 +158,7 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { g.emit("%s = %s(hostarch.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: - g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) + g.emit("%s = %s.UnmarshalBytes(%s)\n", bufVar, accessor, bufVar) g.recordPotentiallyNonPackedField(accessor) } } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go index bd7741ae5..1f98d9246 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -56,24 +56,26 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("}\n\n") g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalBytes(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") }) g.emit("}\n") + g.emit("return dst\n") }) g.emit("}\n\n") g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalBytes(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") }) g.emit("}\n") + g.emit("return src\n") }) g.emit("}\n\n") @@ -87,16 +89,20 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(size))\n", g.r) + g.emit("return dst[size:]\n") }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(size))\n", g.r) + g.emit("return src[size:]\n") }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go index 345020ddc..70ae8ef4a 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go @@ -28,18 +28,18 @@ func (g *interfaceGenerator) emitMarshallableForDynamicType() { g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) + g.emit("return %s.MarshalBytes(dst)\n", g.r) }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) + g.emit("return %s.UnmarshalBytes(src)\n", g.r) }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go index ba4b7324e..e2387f032 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -116,16 +116,26 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("}\n\n") g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalBytes(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return dst[%d:]\n", size) + } else { + g.emit("return dst[(*%s)(nil).SizeBytes():]\n", nt.Name) + } }) g.emit("}\n\n") g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalBytes(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return src[%d:]\n", size) + } else { + g.emit("return src[(*%s)(nil).SizeBytes():]\n", nt.Name) + } }) g.emit("}\n\n") @@ -139,16 +149,20 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(size))\n", g.r) + g.emit("return dst[size:]\n") }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(size))\n", g.r) + g.emit("return src[size:]\n") }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 4c47218f1..21177d39c 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -29,7 +29,7 @@ func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { } // areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially +// packed. Returns "", false if g.t has no fields that may be potentially not // packed, otherwise returns <clause>, true, where <clause> is an expression // like "t.a.Packed() && t.b.Packed() && t.c.Packed()". func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { @@ -136,7 +136,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("\n}\n\n") g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalBytes(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { forEachStructField(st, fieldDispatcher{ primitive: func(n, t *ast.Ident) { @@ -186,11 +186,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n") }, }.dispatch) + // All cases above shift the buffer appropriately. + g.emit("return dst\n") }) g.emit("}\n\n") g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalBytes(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { forEachStructField(st, fieldDispatcher{ primitive: func(n, t *ast.Ident) { @@ -242,6 +244,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n") }, }.dispatch) + // All cases above shift the buffer appropriately. + g.emit("return src\n") }) g.emit("}\n\n") @@ -263,25 +267,27 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) + g.emit("return %s.MarshalBytes(dst)\n", g.r) } if thisPacked { g.recordUsedImport("gohacks") g.recordUsedImport("unsafe") + fastMarshal := func() { + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(size))\n", g.r) + g.emit("return dst[size:]\n") + } if cond, ok := g.areFieldsPackedExpression(); ok { g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("} else {\n") - g.inIndent(fallback) + g.inIndent(fastMarshal) g.emit("}\n") + fallback() } else { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) + fastMarshal() } } else { fallback() @@ -290,24 +296,27 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) + g.emit("return %s.UnmarshalBytes(src)\n", g.r) } if thisPacked { g.recordUsedImport("gohacks") + g.recordUsedImport("unsafe") + fastUnmarshal := func() { + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(size))\n", g.r) + g.emit("return src[size:]\n") + } if cond, ok := g.areFieldsPackedExpression(); ok { g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("} else {\n") - g.inIndent(fallback) + g.inIndent(fastUnmarshal) g.emit("}\n") + fallback() } else { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + fastUnmarshal() } } else { fallback() @@ -456,7 +465,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("limit := length/size\n") g.emit("for idx := 0; idx < limit; idx++ {\n") g.inIndent(func() { - g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + g.emit("buf = dst[idx].UnmarshalBytes(buf)\n") }) g.emit("}\n\n") @@ -465,8 +474,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("// result in unmarshalling zero values for some parts of the object.\n") g.emit("if length%size != 0 {\n") g.inIndent(func() { - g.emit("idx := limit\n") - g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + g.emit("dst[limit].UnmarshalBytes(buf)\n") }) g.emit("}\n\n") @@ -507,9 +515,10 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := cc.CopyScratchBuffer(size * count)\n") + g.emit("curBuf := buf\n") g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { - g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") + g.emit("curBuf = src[idx].MarshalBytes(curBuf)\n") }) g.emit("}\n") g.emit("return cc.CopyOutBytes(addr, buf)\n") @@ -550,7 +559,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { - g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n") + g.emit("dst = src[idx].MarshalBytes(dst)\n") }) g.emit("}\n") g.emit("return size * count, nil\n") @@ -589,7 +598,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { - g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n") + g.emit("src = dst[idx].UnmarshalBytes(src)\n") }) g.emit("}\n") g.emit("return size * count, nil\n") diff --git a/tools/go_marshal/test/dynamic.go b/tools/go_marshal/test/dynamic.go index 9a812efe9..46b446392 100644 --- a/tools/go_marshal/test/dynamic.go +++ b/tools/go_marshal/test/dynamic.go @@ -31,25 +31,26 @@ func (t *Type12Dynamic) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *Type12Dynamic) MarshalBytes(dst []byte) { - t.X.MarshalBytes(dst) - dst = dst[t.X.SizeBytes():] - for i, x := range t.Y { - x.MarshalBytes(dst[i*8 : (i+1)*8]) +func (t *Type12Dynamic) MarshalBytes(dst []byte) []byte { + dst = t.X.MarshalBytes(dst) + for _, x := range t.Y { + dst = x.MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *Type12Dynamic) UnmarshalBytes(src []byte) { - t.X.UnmarshalBytes(src) +func (t *Type12Dynamic) UnmarshalBytes(src []byte) []byte { + src = t.X.UnmarshalBytes(src) if t.Y != nil { t.Y = t.Y[:0] } - for i := t.X.SizeBytes(); i < len(src); i += 8 { + for len(src) > 0 { var x primitive.Int64 - x.UnmarshalBytes(src[i:]) + src = x.UnmarshalBytes(src) t.Y = append(t.Y, x) } + return src } // Type13Dynamic is a dynamically sized struct which depends on the @@ -67,17 +68,16 @@ func (t *Type13Dynamic) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *Type13Dynamic) MarshalBytes(dst []byte) { +func (t *Type13Dynamic) MarshalBytes(dst []byte) []byte { strLen := primitive.Uint32(len(*t)) - strLen.MarshalBytes(dst) - dst = dst[strLen.SizeBytes():] - copy(dst[:strLen], *t) + dst = strLen.MarshalBytes(dst) + return dst[copy(dst[:strLen], *t):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *Type13Dynamic) UnmarshalBytes(src []byte) { +func (t *Type13Dynamic) UnmarshalBytes(src []byte) []byte { var strLen primitive.Uint32 - strLen.UnmarshalBytes(src) - src = src[strLen.SizeBytes():] + src = strLen.UnmarshalBytes(src) *t = Type13Dynamic(src[:strLen]) + return src[strLen:] } |