diff options
41 files changed, 629 insertions, 752 deletions
diff --git a/pkg/abi/linux/fuse.go b/pkg/abi/linux/fuse.go index 1070b457c..1112dadd6 100644 --- a/pkg/abi/linux/fuse.go +++ b/pkg/abi/linux/fuse.go @@ -352,6 +352,22 @@ type FUSEEntryOut struct { Attr FUSEAttr } +// CString represents a null terminated string which can be marshalled. +type CString string + +// MarshalBytes implements marshal.Marshallable.MarshalBytes. +func (s *CString) MarshalBytes(buf []byte) []byte { + copy(buf, *s) + buf[len(*s)] = 0 // null char + return buf[s.SizeBytes():] +} + +// SizeBytes implements marshal.Marshallable.SizeBytes. +func (s *CString) SizeBytes() int { + // 1 extra byte for null-terminated string. + return len(*s) + 1 +} + // FUSELookupIn is the request sent by the kernel to the daemon // to look up a file name. // @@ -360,18 +376,17 @@ type FUSELookupIn struct { marshal.StubMarshallable // Name is a file name to be looked up. - Name string + Name CString } // MarshalBytes serializes r.name to the dst buffer. -func (r *FUSELookupIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) +func (r *FUSELookupIn) MarshalBytes(buf []byte) []byte { + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSELookupIn. -// 1 extra byte for null-terminated string. func (r *FUSELookupIn) SizeBytes() int { - return len(r.Name) + 1 + return r.Name.SizeBytes() } // MAX_NON_LFS indicates the maximum offset without large file support. @@ -530,19 +545,18 @@ type FUSECreateIn struct { CreateMeta FUSECreateMeta // Name is the name of the node to create. - Name string + Name CString } // MarshalBytes serializes r.CreateMeta and r.Name to the dst buffer. -func (r *FUSECreateIn) MarshalBytes(buf []byte) { - r.CreateMeta.MarshalBytes(buf[:r.CreateMeta.SizeBytes()]) - copy(buf[r.CreateMeta.SizeBytes():], r.Name) +func (r *FUSECreateIn) MarshalBytes(buf []byte) []byte { + buf = r.CreateMeta.MarshalBytes(buf) + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSECreateIn. -// 1 extra byte for null-terminated string. func (r *FUSECreateIn) SizeBytes() int { - return r.CreateMeta.SizeBytes() + len(r.Name) + 1 + return r.CreateMeta.SizeBytes() + r.Name.SizeBytes() } // FUSEMknodMeta contains all the static fields of FUSEMknodIn, @@ -573,19 +587,18 @@ type FUSEMknodIn struct { MknodMeta FUSEMknodMeta // Name is the name of the node to create. - Name string + Name CString } // MarshalBytes serializes r.MknodMeta and r.Name to the dst buffer. -func (r *FUSEMknodIn) MarshalBytes(buf []byte) { - r.MknodMeta.MarshalBytes(buf[:r.MknodMeta.SizeBytes()]) - copy(buf[r.MknodMeta.SizeBytes():], r.Name) +func (r *FUSEMknodIn) MarshalBytes(buf []byte) []byte { + buf = r.MknodMeta.MarshalBytes(buf) + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSEMknodIn. -// 1 extra byte for null-terminated string. func (r *FUSEMknodIn) SizeBytes() int { - return r.MknodMeta.SizeBytes() + len(r.Name) + 1 + return r.MknodMeta.SizeBytes() + r.Name.SizeBytes() } // FUSESymLinkIn is the request sent by the kernel to the daemon, @@ -596,30 +609,30 @@ type FUSESymLinkIn struct { marshal.StubMarshallable // Name of symlink to create. - Name string + Name CString // Target of the symlink. - Target string + Target CString } // MarshalBytes serializes r.Name and r.Target to the dst buffer. -// Left null-termination at end of r.Name and r.Target. -func (r *FUSESymLinkIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) - copy(buf[len(r.Name)+1:], r.Target) +func (r *FUSESymLinkIn) MarshalBytes(buf []byte) []byte { + buf = r.Name.MarshalBytes(buf) + return r.Target.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSESymLinkIn. -// 2 extra bytes for null-terminated string. func (r *FUSESymLinkIn) SizeBytes() int { - return len(r.Name) + len(r.Target) + 2 + return r.Name.SizeBytes() + r.Target.SizeBytes() } // FUSEEmptyIn is used by operations without request body. type FUSEEmptyIn struct{ marshal.StubMarshallable } // MarshalBytes do nothing for marshal. -func (r *FUSEEmptyIn) MarshalBytes(buf []byte) {} +func (r *FUSEEmptyIn) MarshalBytes(buf []byte) []byte { + return buf +} // SizeBytes is 0 for empty request. func (r *FUSEEmptyIn) SizeBytes() int { @@ -649,19 +662,18 @@ type FUSEMkdirIn struct { MkdirMeta FUSEMkdirMeta // Name of the directory to create. - Name string + Name CString } // MarshalBytes serializes r.MkdirMeta and r.Name to the dst buffer. -func (r *FUSEMkdirIn) MarshalBytes(buf []byte) { - r.MkdirMeta.MarshalBytes(buf[:r.MkdirMeta.SizeBytes()]) - copy(buf[r.MkdirMeta.SizeBytes():], r.Name) +func (r *FUSEMkdirIn) MarshalBytes(buf []byte) []byte { + buf = r.MkdirMeta.MarshalBytes(buf) + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSEMkdirIn. -// 1 extra byte for null-terminated Name string. func (r *FUSEMkdirIn) SizeBytes() int { - return r.MkdirMeta.SizeBytes() + len(r.Name) + 1 + return r.MkdirMeta.SizeBytes() + r.Name.SizeBytes() } // FUSERmDirIn is the request sent by the kernel to the daemon @@ -672,17 +684,17 @@ type FUSERmDirIn struct { marshal.StubMarshallable // Name is a directory name to be removed. - Name string + Name CString } // MarshalBytes serializes r.name to the dst buffer. -func (r *FUSERmDirIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) +func (r *FUSERmDirIn) MarshalBytes(buf []byte) []byte { + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSERmDirIn. func (r *FUSERmDirIn) SizeBytes() int { - return len(r.Name) + 1 + return r.Name.SizeBytes() } // FUSEDirents is a list of Dirents received from the FUSE daemon server. @@ -738,7 +750,7 @@ func (r *FUSEDirents) SizeBytes() int { } // UnmarshalBytes deserializes FUSEDirents from the src buffer. -func (r *FUSEDirents) UnmarshalBytes(src []byte) { +func (r *FUSEDirents) UnmarshalBytes(src []byte) []byte { for { if len(src) <= (*FUSEDirentMeta)(nil).SizeBytes() { break @@ -754,11 +766,10 @@ func (r *FUSEDirents) UnmarshalBytes(src []byte) { // to do this. Linux allocates 1 page to store all the dirents and then // simply reads them from the page. var dirent FUSEDirent - dirent.UnmarshalBytes(src) + src = dirent.UnmarshalBytes(src) r.Dirents = append(r.Dirents, &dirent) - - src = src[dirent.SizeBytes():] } + return src } // SizeBytes is the size of the memory representation of FUSEDirent. @@ -772,20 +783,20 @@ func (r *FUSEDirent) SizeBytes() int { } // UnmarshalBytes deserializes FUSEDirent from the src buffer. -func (r *FUSEDirent) UnmarshalBytes(src []byte) { - r.Meta.UnmarshalBytes(src) - src = src[r.Meta.SizeBytes():] +func (r *FUSEDirent) UnmarshalBytes(src []byte) []byte { + src = r.Meta.UnmarshalBytes(src) if r.Meta.NameLen > FUSE_NAME_MAX { // The name is too long and therefore invalid. We don't // need to unmarshal the name since it'll be thrown away. - return + return src } buf := make([]byte, r.Meta.NameLen) name := primitive.ByteSlice(buf) name.UnmarshalBytes(src[:r.Meta.NameLen]) r.Name = string(name) + return src[r.Meta.NameLen:] } // FATTR_* consts are the attribute flags defined in include/uapi/linux/fuse.h. @@ -863,17 +874,15 @@ type FUSEUnlinkIn struct { marshal.StubMarshallable // Name of the node to unlink. - Name string + Name CString } -// MarshalBytes serializes r.name to the dst buffer, which should -// have size len(r.Name) + 1 and last byte set to 0. -func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) { - copy(buf, r.Name) +// MarshalBytes serializes r.name to the dst buffer. +func (r *FUSEUnlinkIn) MarshalBytes(buf []byte) []byte { + return r.Name.MarshalBytes(buf) } // SizeBytes is the size of the memory representation of FUSEUnlinkIn. -// 1 extra byte for null-terminated Name string. func (r *FUSEUnlinkIn) SizeBytes() int { - return len(r.Name) + 1 + return r.Name.SizeBytes() } diff --git a/pkg/abi/linux/msgqueue.go b/pkg/abi/linux/msgqueue.go index 0612a8214..6f8eb4dd9 100644 --- a/pkg/abi/linux/msgqueue.go +++ b/pkg/abi/linux/msgqueue.go @@ -82,15 +82,15 @@ func (b *MsgBuf) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (b *MsgBuf) MarshalBytes(dst []byte) { - b.Type.MarshalUnsafe(dst) - b.Text.MarshalBytes(dst[b.Type.SizeBytes():]) +func (b *MsgBuf) MarshalBytes(dst []byte) []byte { + dst = b.Type.MarshalUnsafe(dst) + return b.Text.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (b *MsgBuf) UnmarshalBytes(src []byte) { - b.Type.UnmarshalUnsafe(src) - b.Text.UnmarshalBytes(src[b.Type.SizeBytes():]) +func (b *MsgBuf) UnmarshalBytes(src []byte) []byte { + src = b.Type.UnmarshalUnsafe(src) + return b.Text.UnmarshalBytes(src) } // MsgInfo is equivelant to struct msginfo. Source: include/uapi/linux/msg.h diff --git a/pkg/abi/linux/netfilter.go b/pkg/abi/linux/netfilter.go index 3fd05483a..1470a5578 100644 --- a/pkg/abi/linux/netfilter.go +++ b/pkg/abi/linux/netfilter.go @@ -144,15 +144,15 @@ func (ke *KernelIPTEntry) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIPTEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalUnsafe(dst) - ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +func (ke *KernelIPTEntry) MarshalBytes(dst []byte) []byte { + dst = ke.Entry.MarshalUnsafe(dst) + return ke.Elems.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalUnsafe(src) - ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) +func (ke *KernelIPTEntry) UnmarshalBytes(src []byte) []byte { + src = ke.Entry.UnmarshalUnsafe(src) + return ke.Elems.UnmarshalBytes(src) } var _ marshal.Marshallable = (*KernelIPTEntry)(nil) @@ -455,23 +455,21 @@ func (ke *KernelIPTGetEntries) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalUnsafe(dst) - marshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIPTGetEntries) MarshalBytes(dst []byte) []byte { + dst = ke.IPTGetEntries.MarshalUnsafe(dst) for i := range ke.Entrytable { - ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) - marshalledUntil += ke.Entrytable[i].SizeBytes() + dst = ke.Entrytable[i].MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalUnsafe(src) - unmarshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIPTGetEntries) UnmarshalBytes(src []byte) []byte { + src = ke.IPTGetEntries.UnmarshalUnsafe(src) for i := range ke.Entrytable { - ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) - unmarshalledUntil += ke.Entrytable[i].SizeBytes() + src = ke.Entrytable[i].UnmarshalBytes(src) } + return src } var _ marshal.Marshallable = (*KernelIPTGetEntries)(nil) diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index f8c0e891e..aba0202ef 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -84,23 +84,21 @@ func (ke *KernelIP6TGetEntries) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) { - ke.IPTGetEntries.MarshalUnsafe(dst) - marshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIP6TGetEntries) MarshalBytes(dst []byte) []byte { + dst = ke.IPTGetEntries.MarshalUnsafe(dst) for i := range ke.Entrytable { - ke.Entrytable[i].MarshalBytes(dst[marshalledUntil:]) - marshalledUntil += ke.Entrytable[i].SizeBytes() + dst = ke.Entrytable[i].MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) { - ke.IPTGetEntries.UnmarshalUnsafe(src) - unmarshalledUntil := ke.IPTGetEntries.SizeBytes() +func (ke *KernelIP6TGetEntries) UnmarshalBytes(src []byte) []byte { + src = ke.IPTGetEntries.UnmarshalUnsafe(src) for i := range ke.Entrytable { - ke.Entrytable[i].UnmarshalBytes(src[unmarshalledUntil:]) - unmarshalledUntil += ke.Entrytable[i].SizeBytes() + src = ke.Entrytable[i].UnmarshalBytes(src) } + return src } var _ marshal.Marshallable = (*KernelIP6TGetEntries)(nil) @@ -166,17 +164,19 @@ func (ke *KernelIP6TEntry) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) { - ke.Entry.MarshalUnsafe(dst) - ke.Elems.MarshalBytes(dst[ke.Entry.SizeBytes():]) +func (ke *KernelIP6TEntry) MarshalBytes(dst []byte) []byte { + dst = ke.Entry.MarshalUnsafe(dst) + return ke.Elems.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) { - ke.Entry.UnmarshalUnsafe(src) - ke.Elems.UnmarshalBytes(src[ke.Entry.SizeBytes():]) +func (ke *KernelIP6TEntry) UnmarshalBytes(src []byte) []byte { + src = ke.Entry.UnmarshalUnsafe(src) + return ke.Elems.UnmarshalBytes(src) } +var _ marshal.Marshallable = (*KernelIP6TEntry)(nil) + // IP6TIP contains information for matching a packet's IP header. // It corresponds to struct ip6t_ip6 in // include/uapi/linux/netfilter_ipv6/ip6_tables.h. diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index f60e42997..a31690a04 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -555,9 +555,6 @@ type ControlMessageIPv6PacketInfo struct { // ControlMessageCredentials struct. var SizeOfControlMessageCredentials = (*ControlMessageCredentials)(nil).SizeBytes() -// A ControlMessageRights is an SCM_RIGHTS socket control message. -type ControlMessageRights []int32 - // SizeOfControlMessageRight is the size of a single element in // ControlMessageRights. const SizeOfControlMessageRight = 4 diff --git a/pkg/lisafs/client.go b/pkg/lisafs/client.go index ccf1b9f72..e0f278b5c 100644 --- a/pkg/lisafs/client.go +++ b/pkg/lisafs/client.go @@ -313,7 +313,7 @@ func (c *Client) SyncFDs(ctx context.Context, fds []FDID) error { // implicit conversion to an interface leads to an allocation. // // Precondition: reqMarshal and respUnmarshal must be non-nil. -func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte), respUnmarshal func(src []byte), respFDs []int) error { +func (c *Client) SndRcvMessage(m MID, payloadLen uint32, reqMarshal func(dst []byte) []byte, respUnmarshal func(src []byte) []byte, respFDs []int) error { if !c.IsSupported(m) { return unix.EOPNOTSUPP } diff --git a/pkg/lisafs/message.go b/pkg/lisafs/message.go index 722afd0be..0d7c30ce3 100644 --- a/pkg/lisafs/message.go +++ b/pkg/lisafs/message.go @@ -157,7 +157,6 @@ func MaxMessageSize() uint32 { // TODO(gvisor.dev/issue/6450): Once this is resolved: // * Update manual implementations and function signatures. // * Update RPC handlers and appropriate callers to handle errors correctly. -// * Update manual implementations to get rid of buffer shifting. // UID represents a user ID. // @@ -180,10 +179,10 @@ func (gid GID) Ok() bool { } // NoopMarshal is a noop implementation of marshal.Marshallable.MarshalBytes. -func NoopMarshal([]byte) {} +func NoopMarshal(b []byte) []byte { return b } // NoopUnmarshal is a noop implementation of marshal.Marshallable.UnmarshalBytes. -func NoopUnmarshal([]byte) {} +func NoopUnmarshal(b []byte) []byte { return b } // SizedString represents a string in memory. The marshalled string bytes are // preceded by a uint32 signifying the string length. @@ -195,21 +194,20 @@ func (s *SizedString) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *SizedString) MarshalBytes(dst []byte) { +func (s *SizedString) MarshalBytes(dst []byte) []byte { strLen := primitive.Uint32(len(*s)) - strLen.MarshalUnsafe(dst) - dst = dst[strLen.SizeBytes():] + dst = strLen.MarshalUnsafe(dst) // Copy without any allocation. - copy(dst[:strLen], *s) + return dst[copy(dst[:strLen], *s):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *SizedString) UnmarshalBytes(src []byte) { +func (s *SizedString) UnmarshalBytes(src []byte) []byte { var strLen primitive.Uint32 - strLen.UnmarshalUnsafe(src) - src = src[strLen.SizeBytes():] + src = strLen.UnmarshalUnsafe(src) // Take the hit, this leads to an allocation + memcpy. No way around it. *s = SizedString(src[:strLen]) + return src[strLen:] } // StringArray represents an array of SizedStrings in memory. The marshalled @@ -227,22 +225,20 @@ func (s *StringArray) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *StringArray) MarshalBytes(dst []byte) { +func (s *StringArray) MarshalBytes(dst []byte) []byte { arrLen := primitive.Uint32(len(*s)) - arrLen.MarshalUnsafe(dst) - dst = dst[arrLen.SizeBytes():] + dst = arrLen.MarshalUnsafe(dst) for _, str := range *s { sstr := SizedString(str) - sstr.MarshalBytes(dst) - dst = dst[sstr.SizeBytes():] + dst = sstr.MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *StringArray) UnmarshalBytes(src []byte) { +func (s *StringArray) UnmarshalBytes(src []byte) []byte { var arrLen primitive.Uint32 - arrLen.UnmarshalUnsafe(src) - src = src[arrLen.SizeBytes():] + src = arrLen.UnmarshalUnsafe(src) if cap(*s) < int(arrLen) { *s = make([]string, arrLen) @@ -252,10 +248,10 @@ func (s *StringArray) UnmarshalBytes(src []byte) { for i := primitive.Uint32(0); i < arrLen; i++ { var sstr SizedString - sstr.UnmarshalBytes(src) - src = src[sstr.SizeBytes():] + src = sstr.UnmarshalBytes(src) (*s)[i] = string(sstr) } + return src } // Inode represents an inode on the remote filesystem. @@ -278,13 +274,13 @@ func (m *MountReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MountReq) MarshalBytes(dst []byte) { - m.MountPath.MarshalBytes(dst) +func (m *MountReq) MarshalBytes(dst []byte) []byte { + return m.MountPath.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MountReq) UnmarshalBytes(src []byte) { - m.MountPath.UnmarshalBytes(src) +func (m *MountReq) UnmarshalBytes(src []byte) []byte { + return m.MountPath.UnmarshalBytes(src) } // MountResp represents a Mount response. @@ -306,28 +302,30 @@ func (m *MountResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MountResp) MarshalBytes(dst []byte) { - m.Root.MarshalUnsafe(dst) - dst = dst[m.Root.SizeBytes():] - m.MaxMessageSize.MarshalUnsafe(dst) - dst = dst[m.MaxMessageSize.SizeBytes():] +func (m *MountResp) MarshalBytes(dst []byte) []byte { + dst = m.Root.MarshalUnsafe(dst) + dst = m.MaxMessageSize.MarshalUnsafe(dst) numSupported := primitive.Uint16(len(m.SupportedMs)) - numSupported.MarshalBytes(dst) - dst = dst[numSupported.SizeBytes():] - MarshalUnsafeMIDSlice(m.SupportedMs, dst) + dst = numSupported.MarshalBytes(dst) + n, err := MarshalUnsafeMIDSlice(m.SupportedMs, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MountResp) UnmarshalBytes(src []byte) { - m.Root.UnmarshalUnsafe(src) - src = src[m.Root.SizeBytes():] - m.MaxMessageSize.UnmarshalUnsafe(src) - src = src[m.MaxMessageSize.SizeBytes():] +func (m *MountResp) UnmarshalBytes(src []byte) []byte { + src = m.Root.UnmarshalUnsafe(src) + src = m.MaxMessageSize.UnmarshalUnsafe(src) var numSupported primitive.Uint16 - numSupported.UnmarshalBytes(src) - src = src[numSupported.SizeBytes():] + src = numSupported.UnmarshalBytes(src) m.SupportedMs = make([]MID, numSupported) - UnmarshalUnsafeMIDSlice(m.SupportedMs, src) + n, err := UnmarshalUnsafeMIDSlice(m.SupportedMs, src) + if err != nil { + panic(err) + } + return src[n:] } // ChannelResp is the response to the create channel request. @@ -391,17 +389,15 @@ func (w *WalkReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *WalkReq) MarshalBytes(dst []byte) { - w.DirFD.MarshalUnsafe(dst) - dst = dst[w.DirFD.SizeBytes():] - w.Path.MarshalBytes(dst) +func (w *WalkReq) MarshalBytes(dst []byte) []byte { + dst = w.DirFD.MarshalUnsafe(dst) + return w.Path.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *WalkReq) UnmarshalBytes(src []byte) { - w.DirFD.UnmarshalUnsafe(src) - src = src[w.DirFD.SizeBytes():] - w.Path.UnmarshalBytes(src) +func (w *WalkReq) UnmarshalBytes(src []byte) []byte { + src = w.DirFD.UnmarshalUnsafe(src) + return w.Path.UnmarshalBytes(src) } // WalkStatus is used to indicate the reason for partial/unsuccessful server @@ -438,32 +434,36 @@ func (w *WalkResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *WalkResp) MarshalBytes(dst []byte) { - w.Status.MarshalUnsafe(dst) - dst = dst[w.Status.SizeBytes():] +func (w *WalkResp) MarshalBytes(dst []byte) []byte { + dst = w.Status.MarshalUnsafe(dst) numInodes := primitive.Uint32(len(w.Inodes)) - numInodes.MarshalUnsafe(dst) - dst = dst[numInodes.SizeBytes():] + dst = numInodes.MarshalUnsafe(dst) - MarshalUnsafeInodeSlice(w.Inodes, dst) + n, err := MarshalUnsafeInodeSlice(w.Inodes, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *WalkResp) UnmarshalBytes(src []byte) { - w.Status.UnmarshalUnsafe(src) - src = src[w.Status.SizeBytes():] +func (w *WalkResp) UnmarshalBytes(src []byte) []byte { + src = w.Status.UnmarshalUnsafe(src) var numInodes primitive.Uint32 - numInodes.UnmarshalUnsafe(src) - src = src[numInodes.SizeBytes():] + src = numInodes.UnmarshalUnsafe(src) if cap(w.Inodes) < int(numInodes) { w.Inodes = make([]Inode, numInodes) } else { w.Inodes = w.Inodes[:numInodes] } - UnmarshalUnsafeInodeSlice(w.Inodes, src) + n, err := UnmarshalUnsafeInodeSlice(w.Inodes, src) + if err != nil { + panic(err) + } + return src[n:] } // WalkStatResp is used to communicate stat results for WalkStat. @@ -477,26 +477,32 @@ func (w *WalkStatResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *WalkStatResp) MarshalBytes(dst []byte) { +func (w *WalkStatResp) MarshalBytes(dst []byte) []byte { numStats := primitive.Uint32(len(w.Stats)) - numStats.MarshalUnsafe(dst) - dst = dst[numStats.SizeBytes():] + dst = numStats.MarshalUnsafe(dst) - linux.MarshalUnsafeStatxSlice(w.Stats, dst) + n, err := linux.MarshalUnsafeStatxSlice(w.Stats, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *WalkStatResp) UnmarshalBytes(src []byte) { +func (w *WalkStatResp) UnmarshalBytes(src []byte) []byte { var numStats primitive.Uint32 - numStats.UnmarshalUnsafe(src) - src = src[numStats.SizeBytes():] + src = numStats.UnmarshalUnsafe(src) if cap(w.Stats) < int(numStats) { w.Stats = make([]linux.Statx, numStats) } else { w.Stats = w.Stats[:numStats] } - linux.UnmarshalUnsafeStatxSlice(w.Stats, src) + n, err := linux.UnmarshalUnsafeStatxSlice(w.Stats, src) + if err != nil { + panic(err) + } + return src[n:] } // OpenAtReq is used to open existing FDs with the specified flags. @@ -536,21 +542,17 @@ func (o *OpenCreateAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (o *OpenCreateAtReq) MarshalBytes(dst []byte) { - o.createCommon.MarshalUnsafe(dst) - dst = dst[o.createCommon.SizeBytes():] - o.Name.MarshalBytes(dst) - dst = dst[o.Name.SizeBytes():] - o.Flags.MarshalUnsafe(dst) +func (o *OpenCreateAtReq) MarshalBytes(dst []byte) []byte { + dst = o.createCommon.MarshalUnsafe(dst) + dst = o.Name.MarshalBytes(dst) + return o.Flags.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) { - o.createCommon.UnmarshalUnsafe(src) - src = src[o.createCommon.SizeBytes():] - o.Name.UnmarshalBytes(src) - src = src[o.Name.SizeBytes():] - o.Flags.UnmarshalUnsafe(src) +func (o *OpenCreateAtReq) UnmarshalBytes(src []byte) []byte { + src = o.createCommon.UnmarshalUnsafe(src) + src = o.Name.UnmarshalBytes(src) + return o.Flags.UnmarshalUnsafe(src) } // OpenCreateAtResp is used to communicate successful OpenCreateAt results. @@ -573,24 +575,30 @@ func (f *FdArray) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (f *FdArray) MarshalBytes(dst []byte) { +func (f *FdArray) MarshalBytes(dst []byte) []byte { arrLen := primitive.Uint32(len(*f)) - arrLen.MarshalUnsafe(dst) - dst = dst[arrLen.SizeBytes():] - MarshalUnsafeFDIDSlice(*f, dst) + dst = arrLen.MarshalUnsafe(dst) + n, err := MarshalUnsafeFDIDSlice(*f, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (f *FdArray) UnmarshalBytes(src []byte) { +func (f *FdArray) UnmarshalBytes(src []byte) []byte { var arrLen primitive.Uint32 - arrLen.UnmarshalUnsafe(src) - src = src[arrLen.SizeBytes():] + src = arrLen.UnmarshalUnsafe(src) if cap(*f) < int(arrLen) { *f = make(FdArray, arrLen) } else { *f = (*f)[:arrLen] } - UnmarshalUnsafeFDIDSlice(*f, src) + n, err := UnmarshalUnsafeFDIDSlice(*f, src) + if err != nil { + panic(err) + } + return src[n:] } // CloseReq is used to close(2) FDs. @@ -604,13 +612,13 @@ func (c *CloseReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (c *CloseReq) MarshalBytes(dst []byte) { - c.FDs.MarshalBytes(dst) +func (c *CloseReq) MarshalBytes(dst []byte) []byte { + return c.FDs.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (c *CloseReq) UnmarshalBytes(src []byte) { - c.FDs.UnmarshalBytes(src) +func (c *CloseReq) UnmarshalBytes(src []byte) []byte { + return c.FDs.UnmarshalBytes(src) } // FsyncReq is used to fsync(2) FDs. @@ -624,13 +632,13 @@ func (f *FsyncReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (f *FsyncReq) MarshalBytes(dst []byte) { - f.FDs.MarshalBytes(dst) +func (f *FsyncReq) MarshalBytes(dst []byte) []byte { + return f.FDs.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (f *FsyncReq) UnmarshalBytes(src []byte) { - f.FDs.UnmarshalBytes(src) +func (f *FsyncReq) UnmarshalBytes(src []byte) []byte { + return f.FDs.UnmarshalBytes(src) } // PReadReq is used to pread(2) on an FD. @@ -654,20 +662,18 @@ func (r *PReadResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *PReadResp) MarshalBytes(dst []byte) { - r.NumBytes.MarshalUnsafe(dst) - dst = dst[r.NumBytes.SizeBytes():] - copy(dst[:r.NumBytes], r.Buf[:r.NumBytes]) +func (r *PReadResp) MarshalBytes(dst []byte) []byte { + dst = r.NumBytes.MarshalUnsafe(dst) + return dst[copy(dst[:r.NumBytes], r.Buf[:r.NumBytes]):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *PReadResp) UnmarshalBytes(src []byte) { - r.NumBytes.UnmarshalUnsafe(src) - src = src[r.NumBytes.SizeBytes():] +func (r *PReadResp) UnmarshalBytes(src []byte) []byte { + src = r.NumBytes.UnmarshalUnsafe(src) // We expect the client to have already allocated r.Buf. r.Buf probably // (optimally) points to usermem. Directly copy into that. - copy(r.Buf[:r.NumBytes], src[:r.NumBytes]) + return src[copy(r.Buf[:r.NumBytes], src[:r.NumBytes]):] } // PWriteReq is used to pwrite(2) on an FD. @@ -684,28 +690,23 @@ func (w *PWriteReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (w *PWriteReq) MarshalBytes(dst []byte) { - w.Offset.MarshalUnsafe(dst) - dst = dst[w.Offset.SizeBytes():] - w.FD.MarshalUnsafe(dst) - dst = dst[w.FD.SizeBytes():] - w.NumBytes.MarshalUnsafe(dst) - dst = dst[w.NumBytes.SizeBytes():] - copy(dst[:w.NumBytes], w.Buf[:w.NumBytes]) +func (w *PWriteReq) MarshalBytes(dst []byte) []byte { + dst = w.Offset.MarshalUnsafe(dst) + dst = w.FD.MarshalUnsafe(dst) + dst = w.NumBytes.MarshalUnsafe(dst) + return dst[copy(dst[:w.NumBytes], w.Buf[:w.NumBytes]):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (w *PWriteReq) UnmarshalBytes(src []byte) { - w.Offset.UnmarshalUnsafe(src) - src = src[w.Offset.SizeBytes():] - w.FD.UnmarshalUnsafe(src) - src = src[w.FD.SizeBytes():] - w.NumBytes.UnmarshalUnsafe(src) - src = src[w.NumBytes.SizeBytes():] +func (w *PWriteReq) UnmarshalBytes(src []byte) []byte { + src = w.Offset.UnmarshalUnsafe(src) + src = w.FD.UnmarshalUnsafe(src) + src = w.NumBytes.UnmarshalUnsafe(src) // This is an optimization. Assuming that the server is making this call, it // is safe to just point to src rather than allocating and copying. w.Buf = src[:w.NumBytes] + return src[w.NumBytes:] } // PWriteResp is used to return the result of pwrite(2). @@ -727,17 +728,15 @@ func (m *MkdirAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MkdirAtReq) MarshalBytes(dst []byte) { - m.createCommon.MarshalUnsafe(dst) - dst = dst[m.createCommon.SizeBytes():] - m.Name.MarshalBytes(dst) +func (m *MkdirAtReq) MarshalBytes(dst []byte) []byte { + dst = m.createCommon.MarshalUnsafe(dst) + return m.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MkdirAtReq) UnmarshalBytes(src []byte) { - m.createCommon.UnmarshalUnsafe(src) - src = src[m.createCommon.SizeBytes():] - m.Name.UnmarshalBytes(src) +func (m *MkdirAtReq) UnmarshalBytes(src []byte) []byte { + src = m.createCommon.UnmarshalUnsafe(src) + return m.Name.UnmarshalBytes(src) } // MkdirAtResp is the response to a successful MkdirAt request. @@ -761,25 +760,19 @@ func (m *MknodAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MknodAtReq) MarshalBytes(dst []byte) { - m.createCommon.MarshalUnsafe(dst) - dst = dst[m.createCommon.SizeBytes():] - m.Name.MarshalBytes(dst) - dst = dst[m.Name.SizeBytes():] - m.Minor.MarshalUnsafe(dst) - dst = dst[m.Minor.SizeBytes():] - m.Major.MarshalUnsafe(dst) +func (m *MknodAtReq) MarshalBytes(dst []byte) []byte { + dst = m.createCommon.MarshalUnsafe(dst) + dst = m.Name.MarshalBytes(dst) + dst = m.Minor.MarshalUnsafe(dst) + return m.Major.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MknodAtReq) UnmarshalBytes(src []byte) { - m.createCommon.UnmarshalUnsafe(src) - src = src[m.createCommon.SizeBytes():] - m.Name.UnmarshalBytes(src) - src = src[m.Name.SizeBytes():] - m.Minor.UnmarshalUnsafe(src) - src = src[m.Minor.SizeBytes():] - m.Major.UnmarshalUnsafe(src) +func (m *MknodAtReq) UnmarshalBytes(src []byte) []byte { + src = m.createCommon.UnmarshalUnsafe(src) + src = m.Name.UnmarshalBytes(src) + src = m.Minor.UnmarshalUnsafe(src) + return m.Major.UnmarshalUnsafe(src) } // MknodAtResp is the response to a successful MknodAt request. @@ -804,29 +797,21 @@ func (s *SymlinkAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *SymlinkAtReq) MarshalBytes(dst []byte) { - s.DirFD.MarshalUnsafe(dst) - dst = dst[s.DirFD.SizeBytes():] - s.Name.MarshalBytes(dst) - dst = dst[s.Name.SizeBytes():] - s.Target.MarshalBytes(dst) - dst = dst[s.Target.SizeBytes():] - s.UID.MarshalUnsafe(dst) - dst = dst[s.UID.SizeBytes():] - s.GID.MarshalUnsafe(dst) +func (s *SymlinkAtReq) MarshalBytes(dst []byte) []byte { + dst = s.DirFD.MarshalUnsafe(dst) + dst = s.Name.MarshalBytes(dst) + dst = s.Target.MarshalBytes(dst) + dst = s.UID.MarshalUnsafe(dst) + return s.GID.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *SymlinkAtReq) UnmarshalBytes(src []byte) { - s.DirFD.UnmarshalUnsafe(src) - src = src[s.DirFD.SizeBytes():] - s.Name.UnmarshalBytes(src) - src = src[s.Name.SizeBytes():] - s.Target.UnmarshalBytes(src) - src = src[s.Target.SizeBytes():] - s.UID.UnmarshalUnsafe(src) - src = src[s.UID.SizeBytes():] - s.GID.UnmarshalUnsafe(src) +func (s *SymlinkAtReq) UnmarshalBytes(src []byte) []byte { + src = s.DirFD.UnmarshalUnsafe(src) + src = s.Name.UnmarshalBytes(src) + src = s.Target.UnmarshalBytes(src) + src = s.UID.UnmarshalUnsafe(src) + return s.GID.UnmarshalUnsafe(src) } // SymlinkAtResp is the response to a successful SymlinkAt request. @@ -849,21 +834,17 @@ func (l *LinkAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (l *LinkAtReq) MarshalBytes(dst []byte) { - l.DirFD.MarshalUnsafe(dst) - dst = dst[l.DirFD.SizeBytes():] - l.Target.MarshalUnsafe(dst) - dst = dst[l.Target.SizeBytes():] - l.Name.MarshalBytes(dst) +func (l *LinkAtReq) MarshalBytes(dst []byte) []byte { + dst = l.DirFD.MarshalUnsafe(dst) + dst = l.Target.MarshalUnsafe(dst) + return l.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (l *LinkAtReq) UnmarshalBytes(src []byte) { - l.DirFD.UnmarshalUnsafe(src) - src = src[l.DirFD.SizeBytes():] - l.Target.UnmarshalUnsafe(src) - src = src[l.Target.SizeBytes():] - l.Name.UnmarshalBytes(src) +func (l *LinkAtReq) UnmarshalBytes(src []byte) []byte { + src = l.DirFD.UnmarshalUnsafe(src) + src = l.Target.UnmarshalUnsafe(src) + return l.Name.UnmarshalBytes(src) } // LinkAtResp is used to respond to a successful LinkAt request. @@ -923,13 +904,13 @@ func (r *ReadLinkAtResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *ReadLinkAtResp) MarshalBytes(dst []byte) { - r.Target.MarshalBytes(dst) +func (r *ReadLinkAtResp) MarshalBytes(dst []byte) []byte { + return r.Target.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) { - r.Target.UnmarshalBytes(src) +func (r *ReadLinkAtResp) UnmarshalBytes(src []byte) []byte { + return r.Target.UnmarshalBytes(src) } // FlushReq is used to make Flush requests. @@ -963,21 +944,17 @@ func (u *UnlinkAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (u *UnlinkAtReq) MarshalBytes(dst []byte) { - u.DirFD.MarshalUnsafe(dst) - dst = dst[u.DirFD.SizeBytes():] - u.Name.MarshalBytes(dst) - dst = dst[u.Name.SizeBytes():] - u.Flags.MarshalUnsafe(dst) +func (u *UnlinkAtReq) MarshalBytes(dst []byte) []byte { + dst = u.DirFD.MarshalUnsafe(dst) + dst = u.Name.MarshalBytes(dst) + return u.Flags.MarshalUnsafe(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (u *UnlinkAtReq) UnmarshalBytes(src []byte) { - u.DirFD.UnmarshalUnsafe(src) - src = src[u.DirFD.SizeBytes():] - u.Name.UnmarshalBytes(src) - src = src[u.Name.SizeBytes():] - u.Flags.UnmarshalUnsafe(src) +func (u *UnlinkAtReq) UnmarshalBytes(src []byte) []byte { + src = u.DirFD.UnmarshalUnsafe(src) + src = u.Name.UnmarshalBytes(src) + return u.Flags.UnmarshalUnsafe(src) } // RenameAtReq is used to make Rename requests. Note that the request takes in @@ -994,21 +971,17 @@ func (r *RenameAtReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *RenameAtReq) MarshalBytes(dst []byte) { - r.Renamed.MarshalUnsafe(dst) - dst = dst[r.Renamed.SizeBytes():] - r.NewDir.MarshalUnsafe(dst) - dst = dst[r.NewDir.SizeBytes():] - r.NewName.MarshalBytes(dst) +func (r *RenameAtReq) MarshalBytes(dst []byte) []byte { + dst = r.Renamed.MarshalUnsafe(dst) + dst = r.NewDir.MarshalUnsafe(dst) + return r.NewName.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *RenameAtReq) UnmarshalBytes(src []byte) { - r.Renamed.UnmarshalUnsafe(src) - src = src[r.Renamed.SizeBytes():] - r.NewDir.UnmarshalUnsafe(src) - src = src[r.NewDir.SizeBytes():] - r.NewName.UnmarshalBytes(src) +func (r *RenameAtReq) UnmarshalBytes(src []byte) []byte { + src = r.Renamed.UnmarshalUnsafe(src) + src = r.NewDir.UnmarshalUnsafe(src) + return r.NewName.UnmarshalBytes(src) } // Getdents64Req is used to make Getdents64 requests. @@ -1039,33 +1012,23 @@ func (d *Dirent64) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (d *Dirent64) MarshalBytes(dst []byte) { - d.Ino.MarshalUnsafe(dst) - dst = dst[d.Ino.SizeBytes():] - d.DevMinor.MarshalUnsafe(dst) - dst = dst[d.DevMinor.SizeBytes():] - d.DevMajor.MarshalUnsafe(dst) - dst = dst[d.DevMajor.SizeBytes():] - d.Off.MarshalUnsafe(dst) - dst = dst[d.Off.SizeBytes():] - d.Type.MarshalUnsafe(dst) - dst = dst[d.Type.SizeBytes():] - d.Name.MarshalBytes(dst) +func (d *Dirent64) MarshalBytes(dst []byte) []byte { + dst = d.Ino.MarshalUnsafe(dst) + dst = d.DevMinor.MarshalUnsafe(dst) + dst = d.DevMajor.MarshalUnsafe(dst) + dst = d.Off.MarshalUnsafe(dst) + dst = d.Type.MarshalUnsafe(dst) + return d.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (d *Dirent64) UnmarshalBytes(src []byte) { - d.Ino.UnmarshalUnsafe(src) - src = src[d.Ino.SizeBytes():] - d.DevMinor.UnmarshalUnsafe(src) - src = src[d.DevMinor.SizeBytes():] - d.DevMajor.UnmarshalUnsafe(src) - src = src[d.DevMajor.SizeBytes():] - d.Off.UnmarshalUnsafe(src) - src = src[d.Off.SizeBytes():] - d.Type.UnmarshalUnsafe(src) - src = src[d.Type.SizeBytes():] - d.Name.UnmarshalBytes(src) +func (d *Dirent64) UnmarshalBytes(src []byte) []byte { + src = d.Ino.UnmarshalUnsafe(src) + src = d.DevMinor.UnmarshalUnsafe(src) + src = d.DevMajor.UnmarshalUnsafe(src) + src = d.Off.UnmarshalUnsafe(src) + src = d.Type.UnmarshalUnsafe(src) + return d.Name.UnmarshalBytes(src) } // Getdents64Resp is used to communicate getdents64 results. @@ -1083,31 +1046,29 @@ func (g *Getdents64Resp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (g *Getdents64Resp) MarshalBytes(dst []byte) { +func (g *Getdents64Resp) MarshalBytes(dst []byte) []byte { numDirents := primitive.Uint32(len(g.Dirents)) - numDirents.MarshalUnsafe(dst) - dst = dst[numDirents.SizeBytes():] + dst = numDirents.MarshalUnsafe(dst) for i := range g.Dirents { - g.Dirents[i].MarshalBytes(dst) - dst = dst[g.Dirents[i].SizeBytes():] + dst = g.Dirents[i].MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (g *Getdents64Resp) UnmarshalBytes(src []byte) { +func (g *Getdents64Resp) UnmarshalBytes(src []byte) []byte { var numDirents primitive.Uint32 - numDirents.UnmarshalUnsafe(src) + src = numDirents.UnmarshalUnsafe(src) if cap(g.Dirents) < int(numDirents) { g.Dirents = make([]Dirent64, numDirents) } else { g.Dirents = g.Dirents[:numDirents] } - src = src[numDirents.SizeBytes():] for i := range g.Dirents { - g.Dirents[i].UnmarshalBytes(src) - src = src[g.Dirents[i].SizeBytes():] + src = g.Dirents[i].UnmarshalBytes(src) } + return src } // FGetXattrReq is used to make FGetXattr requests. The response to this is @@ -1124,21 +1085,17 @@ func (g *FGetXattrReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (g *FGetXattrReq) MarshalBytes(dst []byte) { - g.FD.MarshalUnsafe(dst) - dst = dst[g.FD.SizeBytes():] - g.BufSize.MarshalUnsafe(dst) - dst = dst[g.BufSize.SizeBytes():] - g.Name.MarshalBytes(dst) +func (g *FGetXattrReq) MarshalBytes(dst []byte) []byte { + dst = g.FD.MarshalUnsafe(dst) + dst = g.BufSize.MarshalUnsafe(dst) + return g.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (g *FGetXattrReq) UnmarshalBytes(src []byte) { - g.FD.UnmarshalUnsafe(src) - src = src[g.FD.SizeBytes():] - g.BufSize.UnmarshalUnsafe(src) - src = src[g.BufSize.SizeBytes():] - g.Name.UnmarshalBytes(src) +func (g *FGetXattrReq) UnmarshalBytes(src []byte) []byte { + src = g.FD.UnmarshalUnsafe(src) + src = g.BufSize.UnmarshalUnsafe(src) + return g.Name.UnmarshalBytes(src) } // FGetXattrResp is used to respond to FGetXattr request. @@ -1152,13 +1109,13 @@ func (g *FGetXattrResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (g *FGetXattrResp) MarshalBytes(dst []byte) { - g.Value.MarshalBytes(dst) +func (g *FGetXattrResp) MarshalBytes(dst []byte) []byte { + return g.Value.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (g *FGetXattrResp) UnmarshalBytes(src []byte) { - g.Value.UnmarshalBytes(src) +func (g *FGetXattrResp) UnmarshalBytes(src []byte) []byte { + return g.Value.UnmarshalBytes(src) } // FSetXattrReq is used to make FSetXattr requests. It has no response. @@ -1175,25 +1132,19 @@ func (s *FSetXattrReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (s *FSetXattrReq) MarshalBytes(dst []byte) { - s.FD.MarshalUnsafe(dst) - dst = dst[s.FD.SizeBytes():] - s.Flags.MarshalUnsafe(dst) - dst = dst[s.Flags.SizeBytes():] - s.Name.MarshalBytes(dst) - dst = dst[s.Name.SizeBytes():] - s.Value.MarshalBytes(dst) +func (s *FSetXattrReq) MarshalBytes(dst []byte) []byte { + dst = s.FD.MarshalUnsafe(dst) + dst = s.Flags.MarshalUnsafe(dst) + dst = s.Name.MarshalBytes(dst) + return s.Value.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (s *FSetXattrReq) UnmarshalBytes(src []byte) { - s.FD.UnmarshalUnsafe(src) - src = src[s.FD.SizeBytes():] - s.Flags.UnmarshalUnsafe(src) - src = src[s.Flags.SizeBytes():] - s.Name.UnmarshalBytes(src) - src = src[s.Name.SizeBytes():] - s.Value.UnmarshalBytes(src) +func (s *FSetXattrReq) UnmarshalBytes(src []byte) []byte { + src = s.FD.UnmarshalUnsafe(src) + src = s.Flags.UnmarshalUnsafe(src) + src = s.Name.UnmarshalBytes(src) + return s.Value.UnmarshalBytes(src) } // FRemoveXattrReq is used to make FRemoveXattr requests. It has no response. @@ -1208,17 +1159,15 @@ func (r *FRemoveXattrReq) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (r *FRemoveXattrReq) MarshalBytes(dst []byte) { - r.FD.MarshalUnsafe(dst) - dst = dst[r.FD.SizeBytes():] - r.Name.MarshalBytes(dst) +func (r *FRemoveXattrReq) MarshalBytes(dst []byte) []byte { + dst = r.FD.MarshalUnsafe(dst) + return r.Name.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) { - r.FD.UnmarshalUnsafe(src) - src = src[r.FD.SizeBytes():] - r.Name.UnmarshalBytes(src) +func (r *FRemoveXattrReq) UnmarshalBytes(src []byte) []byte { + src = r.FD.UnmarshalUnsafe(src) + return r.Name.UnmarshalBytes(src) } // FListXattrReq is used to make FListXattr requests. @@ -1241,11 +1190,11 @@ func (l *FListXattrResp) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (l *FListXattrResp) MarshalBytes(dst []byte) { - l.Xattrs.MarshalBytes(dst) +func (l *FListXattrResp) MarshalBytes(dst []byte) []byte { + return l.Xattrs.MarshalBytes(dst) } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (l *FListXattrResp) UnmarshalBytes(src []byte) { - l.Xattrs.UnmarshalBytes(src) +func (l *FListXattrResp) UnmarshalBytes(src []byte) []byte { + return l.Xattrs.UnmarshalBytes(src) } diff --git a/pkg/lisafs/sample_message.go b/pkg/lisafs/sample_message.go index 3868dfa08..745736b6d 100644 --- a/pkg/lisafs/sample_message.go +++ b/pkg/lisafs/sample_message.go @@ -53,18 +53,24 @@ func (m *MsgDynamic) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (m *MsgDynamic) MarshalBytes(dst []byte) { - m.N.MarshalUnsafe(dst) - dst = dst[m.N.SizeBytes():] - MarshalUnsafeMsg1Slice(m.Arr, dst) +func (m *MsgDynamic) MarshalBytes(dst []byte) []byte { + dst = m.N.MarshalUnsafe(dst) + n, err := MarshalUnsafeMsg1Slice(m.Arr, dst) + if err != nil { + panic(err) + } + return dst[n:] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (m *MsgDynamic) UnmarshalBytes(src []byte) { - m.N.UnmarshalUnsafe(src) - src = src[m.N.SizeBytes():] +func (m *MsgDynamic) UnmarshalBytes(src []byte) []byte { + src = m.N.UnmarshalUnsafe(src) m.Arr = make([]MsgSimple, m.N) - UnmarshalUnsafeMsg1Slice(m.Arr, src) + n, err := UnmarshalUnsafeMsg1Slice(m.Arr, src) + if err != nil { + panic(err) + } + return src[n:] } // Randomize randomizes the contents of m. @@ -90,21 +96,18 @@ func (v *P9Version) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (v *P9Version) MarshalBytes(dst []byte) { - v.MSize.MarshalUnsafe(dst) - dst = dst[v.MSize.SizeBytes():] +func (v *P9Version) MarshalBytes(dst []byte) []byte { + dst = v.MSize.MarshalUnsafe(dst) versionLen := primitive.Uint16(len(v.Version)) - versionLen.MarshalUnsafe(dst) - dst = dst[versionLen.SizeBytes():] - copy(dst, v.Version) + dst = versionLen.MarshalUnsafe(dst) + return dst[copy(dst, v.Version):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (v *P9Version) UnmarshalBytes(src []byte) { - v.MSize.UnmarshalUnsafe(src) - src = src[v.MSize.SizeBytes():] +func (v *P9Version) UnmarshalBytes(src []byte) []byte { + src = v.MSize.UnmarshalUnsafe(src) var versionLen primitive.Uint16 - versionLen.UnmarshalUnsafe(src) - src = src[versionLen.SizeBytes():] + src = versionLen.UnmarshalUnsafe(src) v.Version = string(src[:versionLen]) + return src[versionLen:] } diff --git a/pkg/marshal/marshal.go b/pkg/marshal/marshal.go index 7da450ce8..9e34eae80 100644 --- a/pkg/marshal/marshal.go +++ b/pkg/marshal/marshal.go @@ -59,13 +59,15 @@ type Marshallable interface { // likely make use of the type of these fields). SizeBytes() int - // MarshalBytes serializes a copy of a type to dst. + // MarshalBytes serializes a copy of a type to dst and returns the remaining + // buffer. // Precondition: dst must be at least SizeBytes() in length. - MarshalBytes(dst []byte) + MarshalBytes(dst []byte) []byte - // UnmarshalBytes deserializes a type from src. + // UnmarshalBytes deserializes a type from src and returns the remaining + // buffer. // Precondition: src must be at least SizeBytes() in length. - UnmarshalBytes(src []byte) + UnmarshalBytes(src []byte) []byte // Packed returns true if the marshalled size of the type is the same as the // size it occupies in memory. This happens when the type has no fields @@ -86,7 +88,7 @@ type Marshallable interface { // return false, MarshalUnsafe should fall back to the safer but slower // MarshalBytes. // Precondition: dst must be at least SizeBytes() in length. - MarshalUnsafe(dst []byte) + MarshalUnsafe(dst []byte) []byte // UnmarshalUnsafe deserializes a type by directly copying to the underlying // memory allocated for the object by the runtime. @@ -96,7 +98,7 @@ type Marshallable interface { // UnmarshalUnsafe should fall back to the safer but slower unmarshal // mechanism implemented in UnmarshalBytes. // Precondition: src must be at least SizeBytes() in length. - UnmarshalUnsafe(src []byte) + UnmarshalUnsafe(src []byte) []byte // CopyIn deserializes a Marshallable type from a task's memory. This may // only be called from a task goroutine. This is more efficient than calling diff --git a/pkg/marshal/marshal_impl_util.go b/pkg/marshal/marshal_impl_util.go index 9e6a6fa29..6c1cf7a4c 100644 --- a/pkg/marshal/marshal_impl_util.go +++ b/pkg/marshal/marshal_impl_util.go @@ -38,12 +38,12 @@ func (StubMarshallable) SizeBytes() int { } // MarshalBytes implements Marshallable.MarshalBytes. -func (StubMarshallable) MarshalBytes(dst []byte) { +func (StubMarshallable) MarshalBytes(dst []byte) []byte { panic("Please implement your own MarshalBytes function") } // UnmarshalBytes implements Marshallable.UnmarshalBytes. -func (StubMarshallable) UnmarshalBytes(src []byte) { +func (StubMarshallable) UnmarshalBytes(src []byte) []byte { panic("Please implement your own UnmarshalBytes function") } @@ -53,12 +53,12 @@ func (StubMarshallable) Packed() bool { } // MarshalUnsafe implements Marshallable.MarshalUnsafe. -func (StubMarshallable) MarshalUnsafe(dst []byte) { +func (StubMarshallable) MarshalUnsafe(dst []byte) []byte { panic("Please implement your own MarshalUnsafe function") } // UnmarshalUnsafe implements Marshallable.UnmarshalUnsafe. -func (StubMarshallable) UnmarshalUnsafe(src []byte) { +func (StubMarshallable) UnmarshalUnsafe(src []byte) []byte { panic("Please implement your own UnmarshalUnsafe function") } diff --git a/pkg/marshal/primitive/primitive.go b/pkg/marshal/primitive/primitive.go index 1c49cf082..7ece26933 100644 --- a/pkg/marshal/primitive/primitive.go +++ b/pkg/marshal/primitive/primitive.go @@ -76,13 +76,13 @@ func (b *ByteSlice) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (b *ByteSlice) MarshalBytes(dst []byte) { - copy(dst, *b) +func (b *ByteSlice) MarshalBytes(dst []byte) []byte { + return dst[copy(dst, *b):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (b *ByteSlice) UnmarshalBytes(src []byte) { - copy(*b, src) +func (b *ByteSlice) UnmarshalBytes(src []byte) []byte { + return src[copy(*b, src):] } // Packed implements marshal.Marshallable.Packed. @@ -91,13 +91,13 @@ func (b *ByteSlice) Packed() bool { } // MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (b *ByteSlice) MarshalUnsafe(dst []byte) { - b.MarshalBytes(dst) +func (b *ByteSlice) MarshalUnsafe(dst []byte) []byte { + return b.MarshalBytes(dst) } // UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (b *ByteSlice) UnmarshalUnsafe(src []byte) { - b.UnmarshalBytes(src) +func (b *ByteSlice) UnmarshalUnsafe(src []byte) []byte { + return b.UnmarshalBytes(src) } // CopyIn implements marshal.Marshallable.CopyIn. diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 05c4fbeb2..18497a880 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -77,8 +77,7 @@ go_test( deps = [ "//pkg/abi/linux", "//pkg/errors/linuxerr", - "//pkg/hostarch", - "//pkg/marshal", + "//pkg/marshal/primitive", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/fuse/connection_test.go b/pkg/sentry/fsimpl/fuse/connection_test.go index 1fddd858e..d98d2832b 100644 --- a/pkg/sentry/fsimpl/fuse/connection_test.go +++ b/pkg/sentry/fsimpl/fuse/connection_test.go @@ -20,6 +20,7 @@ import ( "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ) @@ -69,14 +70,10 @@ func TestConnectionAbort(t *testing.T) { t.Fatalf("newTestConnection: %v", err) } - testObj := &testPayload{ - data: rand.Uint32(), - } - var futNormal []*futureResponse - + testObj := primitive.Uint32(rand.Uint32()) for i := 0; i < int(numRequests); i++ { - req := conn.NewRequest(creds, uint32(i), uint64(i), 0, testObj) + req := conn.NewRequest(creds, uint32(i), uint64(i), 0, &testObj) fut, err := conn.callFutureLocked(task, req) if err != nil { t.Fatalf("callFutureLocked failed: %v", err) @@ -102,7 +99,7 @@ func TestConnectionAbort(t *testing.T) { } // After abort, Call() should return directly with ENOTCONN. - req := conn.NewRequest(creds, 0, 0, 0, testObj) + req := conn.NewRequest(creds, 0, 0, 0, &testObj) _, err = conn.Call(task, req) if !linuxerr.Equals(linuxerr.ENOTCONN, err) { t.Fatalf("Incorrect error code received for Call() after connection aborted") diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 8951b5ba8..13b32fc7c 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -21,6 +21,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -215,11 +216,9 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con if err != nil { t.Fatal(err) } - testObj := &testPayload{ - data: rand.Uint32(), - } - req := conn.NewRequest(creds, pid, inode, echoTestOpcode, testObj) + testObj := primitive.Uint32(rand.Uint32()) + req := conn.NewRequest(creds, pid, inode, echoTestOpcode, &testObj) // Queue up a request. // Analogous to Call except it doesn't block on the task. @@ -232,7 +231,7 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con t.Fatalf("Server responded with an error: %v", err) } - var respTestPayload testPayload + var respTestPayload primitive.Uint32 if err := resp.UnmarshalPayload(&respTestPayload); err != nil { t.Fatalf("Unmarshalling payload error: %v", err) } @@ -242,8 +241,8 @@ func fuseClientRun(t *testing.T, s *testutil.System, k *kernel.Kernel, conn *con req.hdr.Unique, resp.hdr.Unique) } - if respTestPayload.data != testObj.data { - t.Fatalf("read incorrect data. Data expected: %v, but got %v", testObj.data, respTestPayload.data) + if respTestPayload != testObj { + t.Fatalf("read incorrect data. Data expected: %d, but got %d", testObj, respTestPayload) } } @@ -256,8 +255,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F // Create the tasks that the server will be using. tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) - var readPayload testPayload + var readPayload primitive.Uint32 serverTask, err := testutil.CreateTask(s.Ctx, "fuse-server", tc, s.MntNs, s.Root, s.Root) if err != nil { t.Fatal(err) @@ -291,8 +290,8 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F } var readFUSEHeaderIn linux.FUSEHeaderIn - readFUSEHeaderIn.UnmarshalUnsafe(inBuf[:inHdrLen]) - readPayload.UnmarshalUnsafe(inBuf[inHdrLen : inHdrLen+payloadLen]) + inBuf = readFUSEHeaderIn.UnmarshalUnsafe(inBuf) + readPayload.UnmarshalUnsafe(inBuf) if readFUSEHeaderIn.Opcode != echoTestOpcode { t.Fatalf("read incorrect data. Header: %v, Payload: %v", readFUSEHeaderIn, readPayload) diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index af16098d2..00b520c31 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -489,7 +489,7 @@ func (i *inode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentr // Lookup implements kernfs.Inode.Lookup. func (i *inode) Lookup(ctx context.Context, name string) (kernfs.Inode, error) { - in := linux.FUSELookupIn{Name: name} + in := linux.FUSELookupIn{Name: linux.CString(name)} return i.newEntry(ctx, name, 0, linux.FUSE_LOOKUP, &in) } @@ -520,7 +520,7 @@ func (i *inode) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) Mode: uint32(opts.Mode) | linux.S_IFREG, Umask: uint32(kernelTask.FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, linux.S_IFREG, linux.FUSE_CREATE, &in) } @@ -533,7 +533,7 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) Rdev: linux.MakeDeviceID(uint16(opts.DevMajor), opts.DevMinor), Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, opts.Mode.FileType(), linux.FUSE_MKNOD, &in) } @@ -541,8 +541,8 @@ func (i *inode) NewNode(ctx context.Context, name string, opts vfs.MknodOptions) // NewSymlink implements kernfs.Inode.NewSymlink. func (i *inode) NewSymlink(ctx context.Context, name, target string) (kernfs.Inode, error) { in := linux.FUSESymLinkIn{ - Name: name, - Target: target, + Name: linux.CString(name), + Target: linux.CString(target), } return i.newEntry(ctx, name, linux.S_IFLNK, linux.FUSE_SYMLINK, &in) } @@ -554,7 +554,7 @@ func (i *inode) Unlink(ctx context.Context, name string, child kernfs.Inode) err log.Warningf("fusefs.Inode.newEntry: couldn't get kernel task from context", i.nodeID) return linuxerr.EINVAL } - in := linux.FUSEUnlinkIn{Name: name} + in := linux.FUSEUnlinkIn{Name: linux.CString(name)} req := i.fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(kernelTask.ThreadID()), i.nodeID, linux.FUSE_UNLINK, &in) res, err := i.fs.conn.Call(kernelTask, req) if err != nil { @@ -571,7 +571,7 @@ func (i *inode) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) Mode: uint32(opts.Mode), Umask: uint32(kernel.TaskFromContext(ctx).FSContext().Umask()), }, - Name: name, + Name: linux.CString(name), } return i.newEntry(ctx, name, linux.S_IFDIR, linux.FUSE_MKDIR, &in) } @@ -581,7 +581,7 @@ func (i *inode) RmDir(ctx context.Context, name string, child kernfs.Inode) erro fusefs := i.fs task, creds := kernel.TaskFromContext(ctx), auth.CredentialsFromContext(ctx) - in := linux.FUSERmDirIn{Name: name} + in := linux.FUSERmDirIn{Name: linux.CString(name)} req := fusefs.conn.NewRequest(creds, uint32(task.ThreadID()), i.nodeID, linux.FUSE_RMDIR, &in) res, err := i.fs.conn.Call(task, req) if err != nil { diff --git a/pkg/sentry/fsimpl/fuse/request_response.go b/pkg/sentry/fsimpl/fuse/request_response.go index 8a72489fa..ec76ec2a4 100644 --- a/pkg/sentry/fsimpl/fuse/request_response.go +++ b/pkg/sentry/fsimpl/fuse/request_response.go @@ -41,7 +41,7 @@ type fuseInitRes struct { } // UnmarshalBytes deserializes src to the initOut attribute in a fuseInitRes. -func (r *fuseInitRes) UnmarshalBytes(src []byte) { +func (r *fuseInitRes) UnmarshalBytes(src []byte) []byte { out := &r.initOut // Introduced before FUSE kernel version 7.13. @@ -70,7 +70,7 @@ func (r *fuseInitRes) UnmarshalBytes(src []byte) { out.MaxPages = uint16(hostarch.ByteOrder.Uint16(src[:2])) src = src[2:] } - _ = src // Remove unused warning. + return src } // SizeBytes is the size of the payload of the FUSE_INIT response. diff --git a/pkg/sentry/fsimpl/fuse/utils_test.go b/pkg/sentry/fsimpl/fuse/utils_test.go index b0bab0066..8d4a2fad3 100644 --- a/pkg/sentry/fsimpl/fuse/utils_test.go +++ b/pkg/sentry/fsimpl/fuse/utils_test.go @@ -15,17 +15,13 @@ package fuse import ( - "io" "testing" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" - - "gvisor.dev/gvisor/pkg/hostarch" ) func setup(t *testing.T) *testutil.System { @@ -70,58 +66,3 @@ func newTestConnection(system *testutil.System, k *kernel.Kernel, maxActiveReque } return fs.conn, &fuseDev.vfsfd, nil } - -type testPayload struct { - marshal.StubMarshallable - data uint32 -} - -// SizeBytes implements marshal.Marshallable.SizeBytes. -func (t *testPayload) SizeBytes() int { - return 4 -} - -// MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *testPayload) MarshalBytes(dst []byte) { - hostarch.ByteOrder.PutUint32(dst[:4], t.data) -} - -// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *testPayload) UnmarshalBytes(src []byte) { - *t = testPayload{data: hostarch.ByteOrder.Uint32(src[:4])} -} - -// Packed implements marshal.Marshallable.Packed. -func (t *testPayload) Packed() bool { - return true -} - -// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe. -func (t *testPayload) MarshalUnsafe(dst []byte) { - t.MarshalBytes(dst) -} - -// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe. -func (t *testPayload) UnmarshalUnsafe(src []byte) { - t.UnmarshalBytes(src) -} - -// CopyOutN implements marshal.Marshallable.CopyOutN. -func (t *testPayload) CopyOutN(task marshal.CopyContext, addr hostarch.Addr, limit int) (int, error) { - panic("not implemented") -} - -// CopyOut implements marshal.Marshallable.CopyOut. -func (t *testPayload) CopyOut(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// CopyIn implements marshal.Marshallable.CopyIn. -func (t *testPayload) CopyIn(task marshal.CopyContext, addr hostarch.Addr) (int, error) { - panic("not implemented") -} - -// WriteTo implements io.WriterTo.WriteTo. -func (t *testPayload) WriteTo(w io.Writer) (int64, error) { - panic("not implemented") -} diff --git a/pkg/sentry/loader/elf.go b/pkg/sentry/loader/elf.go index fb213d109..09b148164 100644 --- a/pkg/sentry/loader/elf.go +++ b/pkg/sentry/loader/elf.go @@ -212,8 +212,7 @@ func parseHeader(ctx context.Context, f fullReader) (elfInfo, error) { phdrs := make([]elf.ProgHeader, hdr.Phnum) for i := range phdrs { var prog64 linux.ElfProg64 - prog64.UnmarshalUnsafe(phdrBuf[:prog64Size]) - phdrBuf = phdrBuf[prog64Size:] + phdrBuf = prog64.UnmarshalUnsafe(phdrBuf) phdrs[i] = elf.ProgHeader{ Type: elf.ProgType(prog64.Type), Flags: elf.ProgFlag(prog64.Flags), diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 6077b2150..4b036b323 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -17,6 +17,8 @@ package control import ( + "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" @@ -29,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "time" ) // SCMCredentials represents a SCM_CREDENTIALS socket control message. @@ -63,10 +64,10 @@ type RightsFiles []*fs.File // NewSCMRights creates a new SCM_RIGHTS socket control message representation // using local sentry FDs. -func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { +func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) { files := make(RightsFiles, 0, len(fds)) for _, fd := range fds { - file := t.GetFile(fd) + file := t.GetFile(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF @@ -486,26 +487,25 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) { var ( cmsgs socket.ControlMessages - fds linux.ControlMessageRights + fds []primitive.Int32 ) - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { return cmsgs, linuxerr.EINVAL } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, linuxerr.EINVAL } - if h.Length > uint64(len(buf)-i) { - return socket.ControlMessages{}, linuxerr.EINVAL - } - i += linux.SizeOfControlMessageHeader length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { + return socket.ControlMessages{}, linuxerr.EINVAL + } switch h.Level { case linux.SOL_SOCKET: @@ -518,11 +518,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) - } - - i += bits.AlignUp(length, width) + curFDs := make([]primitive.Int32, numRights) + primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize]) + fds = append(fds, curFDs...) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -530,23 +528,21 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + creds.UnmarshalUnsafe(buf) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, linuxerr.EINVAL } var ts linux.Timeval - ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(buf) cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true - i += bits.AlignUp(length, width) default: // Unknown message type. @@ -560,9 +556,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + tos.UnmarshalUnsafe(buf) cmsgs.IP.TOS = uint8(tos) - i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -571,19 +566,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) - + packetInfo.UnmarshalUnsafe(buf) cmsgs.IP.PacketInfo = packetInfo - i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -591,9 +583,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -606,18 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + tclass.UnmarshalUnsafe(buf) cmsgs.IP.TClass = uint32(tclass) - i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -625,9 +614,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -635,6 +623,11 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) default: return socket.ControlMessages{}, linuxerr.EINVAL } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] + } } if cmsgs.Unix.Credentials == nil { diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index 0a989cbeb..a638cb955 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -45,10 +46,10 @@ type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message // representation using local sentry FDs. -func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { +func NewSCMRightsVFS2(t *kernel.Task, fds []primitive.Int32) (SCMRightsVFS2, error) { files := make(RightsFilesVFS2, 0, len(fds)) for _, fd := range fds { - file := t.GetFileVFS2(fd) + file := t.GetFileVFS2(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 6e2318f75..a31f3ebec 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -578,7 +578,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} - ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Timestamp = ts.ToTime() } @@ -587,18 +587,18 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IP_TOS: controlMessages.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + tos.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) + packetInfo.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -612,12 +612,12 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + tclass.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -631,7 +631,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.TCP_INQ: controlMessages.IP.HasInq = true var inq primitive.Int32 - inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + inq.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Inq = int32(inq) } } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 61111ac6c..c84ab3fb7 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -138,7 +138,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } var ifinfo linux.InterfaceInfoMessage - ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) + ifinfo.UnmarshalUnsafe(link.Data) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -169,7 +169,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } var ifaddr linux.InterfaceAddrMessage - ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) + ifaddr.UnmarshalUnsafe(addr.Data) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, PrefixLen: ifaddr.PrefixLen, @@ -201,7 +201,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) } var ifRoute linux.RouteMessage - ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) + ifRoute.UnmarshalUnsafe(routeMsg.Data) inetRoute := inet.Route{ Family: ifRoute.Family, DstLen: ifRoute.DstLen, diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 3f1b4a17b..7606d2bbb 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -80,9 +80,8 @@ func marshalEntryMatch(name string, data []byte) []byte { copy(matcher.Name[:], name) buf := make([]byte, size) - entryLen := matcher.XTEntryMatch.SizeBytes() - matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) - copy(buf[entryLen:], matcher.Data) + bufRemain := matcher.XTEntryMatch.MarshalUnsafe(buf) + copy(bufRemain, matcher.Data) return buf } diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index af31cbc5b..6cbfee8b6 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -141,10 +141,9 @@ func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IPTEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IPTEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 6cefe0b9c..902707abf 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -144,10 +144,9 @@ func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IP6TEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IP6TEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 01f2f8c77..f2cdf091d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -176,9 +176,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint // net/ipv4/netfilter/ip_tables.c:translate_table for reference. func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace - replaceBuf := optVal[:linux.SizeOfIPTReplace] - optVal = optVal[linux.SizeOfIPTReplace:] - replace.UnmarshalBytes(replaceBuf) + optVal = replace.UnmarshalBytes(optVal) var table stack.Table switch replace.Name.String() { @@ -306,8 +304,7 @@ func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:match.SizeBytes()] - match.UnmarshalUnsafe(buf) + match.UnmarshalUnsafe(optVal) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 6eff2ae65..d83d7e535 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -73,7 +73,7 @@ func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHe // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) + matchData.UnmarshalUnsafe(buf) nflog("parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index b9c15daab..eaf601543 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -199,7 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) + standardTarget.UnmarshalUnsafe(buf) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -253,7 +253,6 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return nil, syserr.ErrInvalidArgument } var errTgt linux.XTErrorTarget - buf = buf[:linux.SizeOfXTErrorTarget] errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: @@ -316,7 +315,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( } var rt linux.XTRedirectTarget - buf = buf[:linux.SizeOfXTRedirectTarget] rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. @@ -405,8 +403,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -477,7 +474,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var st linux.XTSNATTarget - buf = buf[:linux.SizeOfXTSNATTarget] st.UnmarshalUnsafe(buf) // Copy linux.XTSNATTarget to stack.SNATTarget. @@ -557,8 +553,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // TODO(gvisor.dev/issue/5697): Support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -621,7 +616,9 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) + // Do not advance optVal as targetMake.unmarshal() may unmarshal + // XTEntryTarget again but with some added fields. + target.UnmarshalUnsafe(optVal) return unmarshalTarget(target, filter, optVal) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index e5b73a976..a621a6a16 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -59,7 +59,7 @@ func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.XTTCP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index aa72ee70c..2ca854764 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -59,7 +59,7 @@ func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. var matchData linux.XTUDP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index 968968469..1604b2792 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -25,33 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) -type dummyNetlinkMsg struct { - marshal.StubMarshallable - Foo uint16 -} - -func (*dummyNetlinkMsg) SizeBytes() int { - return 2 -} - -func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { - p := primitive.Uint16(m.Foo) - p.MarshalUnsafe(dst) -} - -func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { - var p primitive.Uint16 - p.UnmarshalUnsafe(src) - m.Foo = uint16(p) -} - func TestParseMessage(t *testing.T) { + dummyNetlinkMsg := primitive.Uint16(0x3130) tests := []struct { desc string input []byte header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg + dataMsg marshal.Marshallable restLen int ok bool }{ @@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 1, ok: true, }, @@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) { t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) } - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) + var dataMsg primitive.Uint16 + _, dataOk := msg.GetData(&dataMsg) if !dataOk { t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) { t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 267155807..19c8f340d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) + sa.UnmarshalUnsafe(b) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c35cf06f6..e38c4e5da 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -660,7 +660,7 @@ func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) + a.UnmarshalBytes(sockaddr) addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), @@ -1839,7 +1839,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1852,7 +1852,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1883,7 +1883,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + v.UnmarshalBytes(optVal) if v != (linux.Linger{}) { socket.SetSockOptEmitUnimplementedEvent(t, name) @@ -2222,12 +2222,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) + req.UnmarshalUnsafe(optVal) return req, nil } var req linux.InetMulticastRequestWithNIC - req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) + req.InetMulticastRequest.UnmarshalUnsafe(optVal) return req, nil } @@ -2237,7 +2237,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) + req.UnmarshalUnsafe(optVal) return req, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fc5431eb1..01073df72 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -595,19 +595,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -738,7 +738,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInetSize]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -751,7 +751,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -767,7 +767,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) + a.UnmarshalUnsafe(addr) // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have // an ethernet address. if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index f4aab25b0..b164d9107 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -161,12 +161,10 @@ var controlMessageType = map[int32]string{ linux.SO_TIMESTAMP: "SO_TIMESTAMP", } -func unmarshalControlMessageRights(src []byte) linux.ControlMessageRights { +func unmarshalControlMessageRights(src []byte) []primitive.Int32 { count := len(src) / linux.SizeOfControlMessageRight - cmr := make(linux.ControlMessageRights, count) - for i, _ := range cmr { - cmr[i] = int32(hostarch.ByteOrder.Uint32(src[i*linux.SizeOfControlMessageRight:])) - } + cmr := make([]primitive.Int32, count) + primitive.UnmarshalUnsafeInt32Slice(cmr, src) return cmr } @@ -182,14 +180,14 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) var strs []string - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { strs = append(strs, "{invalid control message (too short)}") break } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) var skipData bool level := "SOL_SOCKET" @@ -204,7 +202,9 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) typ = fmt.Sprint(h.Type) } - if h.Length > uint64(len(buf)-i) { + width := t.Arch().Width() + length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content extends beyond buffer}", level, @@ -214,9 +214,6 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) break } - i += linux.SizeOfControlMessageHeader - width := t.Arch().Width() - length := int(h.Length) - linux.SizeOfControlMessageHeader if length < 0 { strs = append(strs, fmt.Sprintf( "{level=%s, type=%s, length=%d, content too short}", @@ -229,78 +226,80 @@ func cmsghdr(t *kernel.Task, addr hostarch.Addr, length uint64, maxBytes uint64) if skipData { strs = append(strs, fmt.Sprintf("{level=%s, type=%s, length=%d}", level, typ, h.Length)) - i += bits.AlignUp(length, width) - continue - } - - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) - fds := unmarshalControlMessageRights(buf[i : i+rightsSize]) - rights := make([]string, 0, len(fds)) - for _, fd := range fds { - rights = append(rights, fmt.Sprint(fd)) - } + } else { + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := bits.AlignDown(length, linux.SizeOfControlMessageRight) + fds := unmarshalControlMessageRights(buf[:rightsSize]) + rights := make([]string, 0, len(fds)) + for _, fd := range fds { + rights = append(rights, fmt.Sprint(fd)) + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content: %s}", - level, - typ, - h.Length, - strings.Join(rights, ","), - )) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, content: %s}", level, typ, h.Length, + strings.Join(rights, ","), )) - break - } - var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", - level, - typ, - h.Length, - creds.PID, - creds.UID, - creds.GID, - )) + var creds linux.ControlMessageCredentials + creds.UnmarshalUnsafe(buf) - case linux.SO_TIMESTAMP: - if length < linux.SizeOfTimeval { strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, content too short}", + "{level=%s, type=%s, length=%d, pid: %d, uid: %d, gid: %d}", level, typ, h.Length, + creds.PID, + creds.UID, + creds.GID, )) - break - } - var tv linux.Timeval - tv.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, content too short}", + level, + typ, + h.Length, + )) + break + } - strs = append(strs, fmt.Sprintf( - "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", - level, - typ, - h.Length, - tv.Sec, - tv.Usec, - )) + var tv linux.Timeval + tv.UnmarshalUnsafe(buf) - default: - panic("unreachable") + strs = append(strs, fmt.Sprintf( + "{level=%s, type=%s, length=%d, Sec: %d, Usec: %d}", + level, + typ, + h.Length, + tv.Sec, + tv.Usec, + )) + + default: + panic("unreachable") + } + } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] } - i += bits.AlignUp(length, width) } return fmt.Sprintf("%#x %s", addr, strings.Join(strs, ", ")) diff --git a/tools/go_marshal/gomarshal/generator_interfaces.go b/tools/go_marshal/gomarshal/generator_interfaces.go index 3e643e77f..db135fd74 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces.go +++ b/tools/go_marshal/gomarshal/generator_interfaces.go @@ -132,8 +132,7 @@ func (g *interfaceGenerator) marshalScalar(accessor, typ, bufVar string) { g.emit("hostarch.ByteOrder.PutUint64(%s[:8], uint64(%s))\n", bufVar, accessor) g.shift(bufVar, 8) default: - g.emit("%s.MarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) + g.emit("%s = %s.MarshalBytes(%s)\n", bufVar, accessor, bufVar) } } @@ -159,8 +158,7 @@ func (g *interfaceGenerator) unmarshalScalar(accessor, typ, bufVar string) { g.emit("%s = %s(hostarch.ByteOrder.Uint64(%s[:8]))\n", accessor, typ, bufVar) g.shift(bufVar, 8) default: - g.emit("%s.UnmarshalBytes(%s[:%s.SizeBytes()])\n", accessor, bufVar, accessor) - g.shiftDynamic(bufVar, accessor) + g.emit("%s = %s.UnmarshalBytes(%s)\n", bufVar, accessor, bufVar) g.recordPotentiallyNonPackedField(accessor) } } diff --git a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go index bd7741ae5..1f98d9246 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_array_newtype.go @@ -56,24 +56,26 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("}\n\n") g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalBytes(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.marshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "dst") }) g.emit("}\n") + g.emit("return dst\n") }) g.emit("}\n\n") g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalBytes(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("for idx := 0; idx < %s; idx++ {\n", lenExpr) g.inIndent(func() { g.unmarshalScalar(fmt.Sprintf("%s[idx]", g.r), elt.Name, "src") }) g.emit("}\n") + g.emit("return src\n") }) g.emit("}\n\n") @@ -87,16 +89,20 @@ func (g *interfaceGenerator) emitMarshallableForArrayNewtype(n *ast.Ident, a *as g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(&%s[0]), uintptr(size))\n", g.r) + g.emit("return dst[size:]\n") }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(size))\n", g.r) + g.emit("return src[size:]\n") }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go index 345020ddc..70ae8ef4a 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_dynamic.go @@ -28,18 +28,18 @@ func (g *interfaceGenerator) emitMarshallableForDynamicType() { g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) + g.emit("return %s.MarshalBytes(dst)\n", g.r) }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) + g.emit("return %s.UnmarshalBytes(src)\n", g.r) }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go index ba4b7324e..e2387f032 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_primitive_newtype.go @@ -116,16 +116,26 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("}\n\n") g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalBytes(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.marshalPrimitiveScalar(g.r, nt.Name, "dst") + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return dst[%d:]\n", size) + } else { + g.emit("return dst[(*%s)(nil).SizeBytes():]\n", nt.Name) + } }) g.emit("}\n\n") g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalBytes(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { g.unmarshalPrimitiveScalar(g.r, nt.Name, "src", g.typeName()) + if size, dynamic := g.scalarSize(nt); !dynamic { + g.emit("return src[%d:]\n", size) + } else { + g.emit("return src[(*%s)(nil).SizeBytes():]\n", nt.Name) + } }) g.emit("}\n\n") @@ -139,16 +149,20 @@ func (g *interfaceGenerator) emitMarshallableForPrimitiveNewtype(nt *ast.Ident) g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(size))\n", g.r) + g.emit("return dst[size:]\n") }) g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(size))\n", g.r) + g.emit("return src[size:]\n") }) g.emit("}\n\n") diff --git a/tools/go_marshal/gomarshal/generator_interfaces_struct.go b/tools/go_marshal/gomarshal/generator_interfaces_struct.go index 4c47218f1..21177d39c 100644 --- a/tools/go_marshal/gomarshal/generator_interfaces_struct.go +++ b/tools/go_marshal/gomarshal/generator_interfaces_struct.go @@ -29,7 +29,7 @@ func (g *interfaceGenerator) fieldAccessor(n *ast.Ident) string { } // areFieldsPackedExpression returns a go expression checking whether g.t's fields are -// packed. Returns "", false if g.t has no fields that may be potentially +// packed. Returns "", false if g.t has no fields that may be potentially not // packed, otherwise returns <clause>, true, where <clause> is an expression // like "t.a.Packed() && t.b.Packed() && t.c.Packed()". func (g *interfaceGenerator) areFieldsPackedExpression() (string, bool) { @@ -136,7 +136,7 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("\n}\n\n") g.emit("// MarshalBytes implements marshal.Marshallable.MarshalBytes.\n") - g.emit("func (%s *%s) MarshalBytes(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalBytes(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { forEachStructField(st, fieldDispatcher{ primitive: func(n, t *ast.Ident) { @@ -186,11 +186,13 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n") }, }.dispatch) + // All cases above shift the buffer appropriately. + g.emit("return dst\n") }) g.emit("}\n\n") g.emit("// UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes.\n") - g.emit("func (%s *%s) UnmarshalBytes(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalBytes(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { forEachStructField(st, fieldDispatcher{ primitive: func(n, t *ast.Ident) { @@ -242,6 +244,8 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n") }, }.dispatch) + // All cases above shift the buffer appropriately. + g.emit("return src\n") }) g.emit("}\n\n") @@ -263,25 +267,27 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") g.emit("// MarshalUnsafe implements marshal.Marshallable.MarshalUnsafe.\n") - g.emit("func (%s *%s) MarshalUnsafe(dst []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) MarshalUnsafe(dst []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to MarshalBytes.\n", g.typeName()) - g.emit("%s.MarshalBytes(dst)\n", g.r) + g.emit("return %s.MarshalBytes(dst)\n", g.r) } if thisPacked { g.recordUsedImport("gohacks") g.recordUsedImport("unsafe") + fastMarshal := func() { + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(size))\n", g.r) + g.emit("return dst[size:]\n") + } if cond, ok := g.areFieldsPackedExpression(); ok { g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("} else {\n") - g.inIndent(fallback) + g.inIndent(fastMarshal) g.emit("}\n") + fallback() } else { - g.emit("gohacks.Memmove(unsafe.Pointer(&dst[0]), unsafe.Pointer(%s), uintptr(%s.SizeBytes()))\n", g.r, g.r) + fastMarshal() } } else { fallback() @@ -290,24 +296,27 @@ func (g *interfaceGenerator) emitMarshallableForStruct(st *ast.StructType) { g.emit("}\n\n") g.emit("// UnmarshalUnsafe implements marshal.Marshallable.UnmarshalUnsafe.\n") - g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) {\n", g.r, g.typeName()) + g.emit("func (%s *%s) UnmarshalUnsafe(src []byte) []byte {\n", g.r, g.typeName()) g.inIndent(func() { fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fallback to UnmarshalBytes.\n", g.typeName()) - g.emit("%s.UnmarshalBytes(src)\n", g.r) + g.emit("return %s.UnmarshalBytes(src)\n", g.r) } if thisPacked { g.recordUsedImport("gohacks") + g.recordUsedImport("unsafe") + fastUnmarshal := func() { + g.emit("size := %s.SizeBytes()\n", g.r) + g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(size))\n", g.r) + g.emit("return src[size:]\n") + } if cond, ok := g.areFieldsPackedExpression(); ok { g.emit("if %s {\n", cond) - g.inIndent(func() { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) - }) - g.emit("} else {\n") - g.inIndent(fallback) + g.inIndent(fastUnmarshal) g.emit("}\n") + fallback() } else { - g.emit("gohacks.Memmove(unsafe.Pointer(%s), unsafe.Pointer(&src[0]), uintptr(%s.SizeBytes()))\n", g.r, g.r) + fastUnmarshal() } } else { fallback() @@ -456,7 +465,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("limit := length/size\n") g.emit("for idx := 0; idx < limit; idx++ {\n") g.inIndent(func() { - g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + g.emit("buf = dst[idx].UnmarshalBytes(buf)\n") }) g.emit("}\n\n") @@ -465,8 +474,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("// result in unmarshalling zero values for some parts of the object.\n") g.emit("if length%size != 0 {\n") g.inIndent(func() { - g.emit("idx := limit\n") - g.emit("dst[idx].UnmarshalBytes(buf[size*idx:size*(idx+1)])\n") + g.emit("dst[limit].UnmarshalBytes(buf)\n") }) g.emit("}\n\n") @@ -507,9 +515,10 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, fallback := func() { g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("buf := cc.CopyScratchBuffer(size * count)\n") + g.emit("curBuf := buf\n") g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { - g.emit("src[idx].MarshalBytes(buf[size*idx:size*(idx+1)])\n") + g.emit("curBuf = src[idx].MarshalBytes(curBuf)\n") }) g.emit("}\n") g.emit("return cc.CopyOutBytes(addr, buf)\n") @@ -550,7 +559,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("// Type %s doesn't have a packed layout in memory, fall back to MarshalBytes.\n", g.typeName()) g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { - g.emit("src[idx].MarshalBytes(dst[size*idx:(size)*(idx+1)])\n") + g.emit("dst = src[idx].MarshalBytes(dst)\n") }) g.emit("}\n") g.emit("return size * count, nil\n") @@ -589,7 +598,7 @@ func (g *interfaceGenerator) emitMarshallableSliceForStruct(st *ast.StructType, g.emit("// Type %s doesn't have a packed layout in memory, fall back to UnmarshalBytes.\n", g.typeName()) g.emit("for idx := 0; idx < count; idx++ {\n") g.inIndent(func() { - g.emit("dst[idx].UnmarshalBytes(src[size*idx:size*(idx+1)])\n") + g.emit("src = dst[idx].UnmarshalBytes(src)\n") }) g.emit("}\n") g.emit("return size * count, nil\n") diff --git a/tools/go_marshal/test/dynamic.go b/tools/go_marshal/test/dynamic.go index 9a812efe9..46b446392 100644 --- a/tools/go_marshal/test/dynamic.go +++ b/tools/go_marshal/test/dynamic.go @@ -31,25 +31,26 @@ func (t *Type12Dynamic) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *Type12Dynamic) MarshalBytes(dst []byte) { - t.X.MarshalBytes(dst) - dst = dst[t.X.SizeBytes():] - for i, x := range t.Y { - x.MarshalBytes(dst[i*8 : (i+1)*8]) +func (t *Type12Dynamic) MarshalBytes(dst []byte) []byte { + dst = t.X.MarshalBytes(dst) + for _, x := range t.Y { + dst = x.MarshalBytes(dst) } + return dst } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *Type12Dynamic) UnmarshalBytes(src []byte) { - t.X.UnmarshalBytes(src) +func (t *Type12Dynamic) UnmarshalBytes(src []byte) []byte { + src = t.X.UnmarshalBytes(src) if t.Y != nil { t.Y = t.Y[:0] } - for i := t.X.SizeBytes(); i < len(src); i += 8 { + for len(src) > 0 { var x primitive.Int64 - x.UnmarshalBytes(src[i:]) + src = x.UnmarshalBytes(src) t.Y = append(t.Y, x) } + return src } // Type13Dynamic is a dynamically sized struct which depends on the @@ -67,17 +68,16 @@ func (t *Type13Dynamic) SizeBytes() int { } // MarshalBytes implements marshal.Marshallable.MarshalBytes. -func (t *Type13Dynamic) MarshalBytes(dst []byte) { +func (t *Type13Dynamic) MarshalBytes(dst []byte) []byte { strLen := primitive.Uint32(len(*t)) - strLen.MarshalBytes(dst) - dst = dst[strLen.SizeBytes():] - copy(dst[:strLen], *t) + dst = strLen.MarshalBytes(dst) + return dst[copy(dst[:strLen], *t):] } // UnmarshalBytes implements marshal.Marshallable.UnmarshalBytes. -func (t *Type13Dynamic) UnmarshalBytes(src []byte) { +func (t *Type13Dynamic) UnmarshalBytes(src []byte) []byte { var strLen primitive.Uint32 - strLen.UnmarshalBytes(src) - src = src[strLen.SizeBytes():] + src = strLen.UnmarshalBytes(src) *t = Type13Dynamic(src[:strLen]) + return src[strLen:] } |