diff options
author | Ayush Ranjan <ayushranjan@google.com> | 2021-11-05 10:40:53 -0700 |
---|---|---|
committer | gVisor bot <gvisor-bot@google.com> | 2021-11-05 10:43:49 -0700 |
commit | ce4f4283badb6b07baf9f8e6d99e7a5fd15c92db (patch) | |
tree | 848dc50da62da59dc4a5781f9eb7461c58b71512 /pkg/sentry/socket | |
parent | 822a647018adbd994114cb0dc8932f2853b805aa (diff) |
Make {Un}Marshal{Bytes/Unsafe} return remaining buffer.
Change marshal.Marshallable method signatures to return the remaining buffer.
This makes it easier to implement these method manually. Without this, we would
have to manually do buffer shifting which is error prone.
tools/go_marshal/test:benchmark test does not show change in performance.
Additionally fixed some marshalling bugs in fsimpl/fuse.
Updated multiple callpoints to get rid of redundant slice indexing work and
simplified code using this new signature.
Updates #6450
PiperOrigin-RevId: 407857019
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/control/control.go | 63 | ||||
-rw-r--r-- | pkg/sentry/socket/control/control_vfs2.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 14 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/stack.go | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/extensions.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/ipv4.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/ipv6.go | 5 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/netfilter.go | 7 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/owner_matcher.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/targets.go | 15 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/tcp_matcher.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netfilter/udp_matcher.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/message_test.go | 45 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 14 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 14 |
16 files changed, 82 insertions, 124 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 6077b2150..4b036b323 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -17,6 +17,8 @@ package control import ( + "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/bits" "gvisor.dev/gvisor/pkg/context" @@ -29,7 +31,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "time" ) // SCMCredentials represents a SCM_CREDENTIALS socket control message. @@ -63,10 +64,10 @@ type RightsFiles []*fs.File // NewSCMRights creates a new SCM_RIGHTS socket control message representation // using local sentry FDs. -func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { +func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) { files := make(RightsFiles, 0, len(fds)) for _, fd := range fds { - file := t.GetFile(fd) + file := t.GetFile(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF @@ -486,26 +487,25 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) { var ( cmsgs socket.ControlMessages - fds linux.ControlMessageRights + fds []primitive.Int32 ) - for i := 0; i < len(buf); { - if i+linux.SizeOfControlMessageHeader > len(buf) { + for len(buf) > 0 { + if linux.SizeOfControlMessageHeader > len(buf) { return cmsgs, linuxerr.EINVAL } var h linux.ControlMessageHeader - h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader]) + buf = h.UnmarshalUnsafe(buf) if h.Length < uint64(linux.SizeOfControlMessageHeader) { return socket.ControlMessages{}, linuxerr.EINVAL } - if h.Length > uint64(len(buf)-i) { - return socket.ControlMessages{}, linuxerr.EINVAL - } - i += linux.SizeOfControlMessageHeader length := int(h.Length) - linux.SizeOfControlMessageHeader + if length > len(buf) { + return socket.ControlMessages{}, linuxerr.EINVAL + } switch h.Level { case linux.SOL_SOCKET: @@ -518,11 +518,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) - } - - i += bits.AlignUp(length, width) + curFDs := make([]primitive.Int32, numRights) + primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize]) + fds = append(fds, curFDs...) case linux.SCM_CREDENTIALS: if length < linux.SizeOfControlMessageCredentials { @@ -530,23 +528,21 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } var creds linux.ControlMessageCredentials - creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials]) + creds.UnmarshalUnsafe(buf) scmCreds, err := NewSCMCredentials(t, creds) if err != nil { return socket.ControlMessages{}, err } cmsgs.Unix.Credentials = scmCreds - i += bits.AlignUp(length, width) case linux.SO_TIMESTAMP: if length < linux.SizeOfTimeval { return socket.ControlMessages{}, linuxerr.EINVAL } var ts linux.Timeval - ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(buf) cmsgs.IP.Timestamp = ts.ToTime() cmsgs.IP.HasTimestamp = true - i += bits.AlignUp(length, width) default: // Unknown message type. @@ -560,9 +556,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS]) + tos.UnmarshalUnsafe(buf) cmsgs.IP.TOS = uint8(tos) - i += bits.AlignUp(length, width) case linux.IP_PKTINFO: if length < linux.SizeOfControlMessageIPPacketInfo { @@ -571,19 +566,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) cmsgs.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo]) - + packetInfo.UnmarshalUnsafe(buf) cmsgs.IP.PacketInfo = packetInfo - i += bits.AlignUp(length, width) case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IP_RECVERR: var errCmsg linux.SockErrCMsgIPv4 @@ -591,9 +583,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -606,18 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) } cmsgs.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass]) + tclass.UnmarshalUnsafe(buf) cmsgs.IP.TClass = uint32(tclass) - i += bits.AlignUp(length, width) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 if length < addr.SizeBytes() { return socket.ControlMessages{}, linuxerr.EINVAL } - addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()]) + addr.UnmarshalUnsafe(buf) cmsgs.IP.OriginalDstAddress = &addr - i += bits.AlignUp(length, width) case linux.IPV6_RECVERR: var errCmsg linux.SockErrCMsgIPv6 @@ -625,9 +614,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) return socket.ControlMessages{}, linuxerr.EINVAL } - errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()]) + errCmsg.UnmarshalBytes(buf) cmsgs.IP.SockErr = &errCmsg - i += bits.AlignUp(length, width) default: return socket.ControlMessages{}, linuxerr.EINVAL @@ -635,6 +623,11 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) default: return socket.ControlMessages{}, linuxerr.EINVAL } + if shift := bits.AlignUp(length, width); shift > len(buf) { + buf = buf[:0] + } else { + buf = buf[shift:] + } } if cmsgs.Unix.Credentials == nil { diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index 0a989cbeb..a638cb955 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -18,6 +18,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -45,10 +46,10 @@ type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message // representation using local sentry FDs. -func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { +func NewSCMRightsVFS2(t *kernel.Task, fds []primitive.Int32) (SCMRightsVFS2, error) { files := make(RightsFilesVFS2, 0, len(fds)) for _, fd := range fds { - file := t.GetFileVFS2(fd) + file := t.GetFileVFS2(int32(fd)) if file == nil { files.Release(t) return nil, linuxerr.EBADF diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 6e2318f75..a31f3ebec 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -578,7 +578,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.SO_TIMESTAMP: controlMessages.IP.HasTimestamp = true ts := linux.Timeval{} - ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval]) + ts.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Timestamp = ts.ToTime() } @@ -587,18 +587,18 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IP_TOS: controlMessages.IP.HasTOS = true var tos primitive.Uint8 - tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()]) + tos.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TOS = uint8(tos) case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo - packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()]) + packetInfo.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.PacketInfo = packetInfo case linux.IP_RECVORIGDSTADDR: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IP_RECVERR: @@ -612,12 +612,12 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true var tclass primitive.Uint32 - tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()]) + tclass.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.TClass = uint32(tclass) case linux.IPV6_RECVORIGDSTADDR: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.OriginalDstAddress = &addr case unix.IPV6_RECVERR: @@ -631,7 +631,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s case linux.TCP_INQ: controlMessages.IP.HasInq = true var inq primitive.Int32 - inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq]) + inq.UnmarshalUnsafe(unixCmsg.Data) controlMessages.IP.Inq = int32(inq) } } diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 61111ac6c..c84ab3fb7 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -138,7 +138,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg) } var ifinfo linux.InterfaceInfoMessage - ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()]) + ifinfo.UnmarshalUnsafe(link.Data) inetIF := inet.Interface{ DeviceType: ifinfo.Type, Flags: ifinfo.Flags, @@ -169,7 +169,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg) } var ifaddr linux.InterfaceAddrMessage - ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()]) + ifaddr.UnmarshalUnsafe(addr.Data) inetAddr := inet.InterfaceAddr{ Family: ifaddr.Family, PrefixLen: ifaddr.PrefixLen, @@ -201,7 +201,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) } var ifRoute linux.RouteMessage - ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()]) + ifRoute.UnmarshalUnsafe(routeMsg.Data) inetRoute := inet.Route{ Family: ifRoute.Family, DstLen: ifRoute.DstLen, diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 3f1b4a17b..7606d2bbb 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -80,9 +80,8 @@ func marshalEntryMatch(name string, data []byte) []byte { copy(matcher.Name[:], name) buf := make([]byte, size) - entryLen := matcher.XTEntryMatch.SizeBytes() - matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen]) - copy(buf[entryLen:], matcher.Data) + bufRemain := matcher.XTEntryMatch.MarshalUnsafe(buf) + copy(bufRemain, matcher.Data) return buf } diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index af31cbc5b..6cbfee8b6 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -141,10 +141,9 @@ func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IPTEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IPTEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIPTEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 6cefe0b9c..902707abf 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -144,10 +144,9 @@ func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace nflog("optVal has insufficient size for entry %d", len(optVal)) return nil, syserr.ErrInvalidArgument } - var entry linux.IP6TEntry - entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()]) initialOptValLen := len(optVal) - optVal = optVal[entry.SizeBytes():] + var entry linux.IP6TEntry + optVal = entry.UnmarshalUnsafe(optVal) if entry.TargetOffset < linux.SizeOfIP6TEntry { nflog("entry has too-small target offset %d", entry.TargetOffset) diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 01f2f8c77..f2cdf091d 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -176,9 +176,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint // net/ipv4/netfilter/ip_tables.c:translate_table for reference. func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace - replaceBuf := optVal[:linux.SizeOfIPTReplace] - optVal = optVal[linux.SizeOfIPTReplace:] - replace.UnmarshalBytes(replaceBuf) + optVal = replace.UnmarshalBytes(optVal) var table stack.Table switch replace.Name.String() { @@ -306,8 +304,7 @@ func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) } var match linux.XTEntryMatch - buf := optVal[:match.SizeBytes()] - match.UnmarshalUnsafe(buf) + match.UnmarshalUnsafe(optVal) nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) // Check some invariants. diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go index 6eff2ae65..d83d7e535 100644 --- a/pkg/sentry/socket/netfilter/owner_matcher.go +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -73,7 +73,7 @@ func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHe // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.IPTOwnerInfo - matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo]) + matchData.UnmarshalUnsafe(buf) nflog("parsed IPTOwnerInfo: %+v", matchData) var owner OwnerMatcher diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index b9c15daab..eaf601543 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -199,7 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } var standardTarget linux.XTStandardTarget - standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()]) + standardTarget.UnmarshalUnsafe(buf) if standardTarget.Verdict < 0 { // A Verdict < 0 indicates a non-jump verdict. @@ -253,7 +253,6 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar return nil, syserr.ErrInvalidArgument } var errTgt linux.XTErrorTarget - buf = buf[:linux.SizeOfXTErrorTarget] errTgt.UnmarshalUnsafe(buf) // Error targets are used in 2 cases: @@ -316,7 +315,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( } var rt linux.XTRedirectTarget - buf = buf[:linux.SizeOfXTRedirectTarget] rt.UnmarshalUnsafe(buf) // Copy linux.XTRedirectTarget to stack.RedirectTarget. @@ -405,8 +403,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // We don't support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -477,7 +474,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var st linux.XTSNATTarget - buf = buf[:linux.SizeOfXTSNATTarget] st.UnmarshalUnsafe(buf) // Copy linux.XTSNATTarget to stack.SNATTarget. @@ -557,8 +553,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta } var natRange linux.NFNATRange - buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize] - natRange.UnmarshalUnsafe(buf) + natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:]) // TODO(gvisor.dev/issue/5697): Support port or address ranges. if natRange.MinAddr != natRange.MaxAddr { @@ -621,7 +616,9 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T return nil, syserr.ErrInvalidArgument } var target linux.XTEntryTarget - target.UnmarshalUnsafe(optVal[:target.SizeBytes()]) + // Do not advance optVal as targetMake.unmarshal() may unmarshal + // XTEntryTarget again but with some added fields. + target.UnmarshalUnsafe(optVal) return unmarshalTarget(target, filter, optVal) } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index e5b73a976..a621a6a16 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -59,7 +59,7 @@ func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may // exceed what's strictly necessary to hold matchData. var matchData linux.XTTCP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTTCP: %+v", matchData) if matchData.Option != 0 || diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index aa72ee70c..2ca854764 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -59,7 +59,7 @@ func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF // For alignment reasons, the match's total size may exceed what's // strictly necessary to hold matchData. var matchData linux.XTUDP - matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()]) + matchData.UnmarshalUnsafe(buf) nflog("parseMatchers: parsed XTUDP: %+v", matchData) if matchData.InverseFlags != 0 { diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go index 968968469..1604b2792 100644 --- a/pkg/sentry/socket/netlink/message_test.go +++ b/pkg/sentry/socket/netlink/message_test.go @@ -25,33 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/netlink" ) -type dummyNetlinkMsg struct { - marshal.StubMarshallable - Foo uint16 -} - -func (*dummyNetlinkMsg) SizeBytes() int { - return 2 -} - -func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) { - p := primitive.Uint16(m.Foo) - p.MarshalUnsafe(dst) -} - -func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) { - var p primitive.Uint16 - p.UnmarshalUnsafe(src) - m.Foo = uint16(p) -} - func TestParseMessage(t *testing.T) { + dummyNetlinkMsg := primitive.Uint16(0x3130) tests := []struct { desc string input []byte header linux.NetlinkMessageHeader - dataMsg *dummyNetlinkMsg + dataMsg marshal.Marshallable restLen int ok bool }{ @@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 1, ok: true, }, @@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) { Seq: 3, PortID: 4, }, - dataMsg: &dummyNetlinkMsg{ - Foo: 0x3130, - }, + dataMsg: &dummyNetlinkMsg, restLen: 0, ok: true, }, @@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) { t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) } - dataMsg := &dummyNetlinkMsg{} - _, dataOk := msg.GetData(dataMsg) + var dataMsg primitive.Uint16 + _, dataOk := msg.GetData(&dataMsg) if !dataOk { t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) - } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) { t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 267155807..19c8f340d 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { } var sa linux.SockAddrNetlink - sa.UnmarshalUnsafe(b[:sa.SizeBytes()]) + sa.UnmarshalUnsafe(b) if sa.Family != linux.AF_NETLINK { return nil, syserr.ErrInvalidArgument diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index c35cf06f6..e38c4e5da 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -660,7 +660,7 @@ func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) < sockAddrLinkSize { return syserr.ErrInvalidArgument } - a.UnmarshalBytes(sockaddr[:sockAddrLinkSize]) + a.UnmarshalBytes(sockaddr) addr = tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), @@ -1839,7 +1839,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1852,7 +1852,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Timeval - v.UnmarshalBytes(optVal[:linux.SizeOfTimeval]) + v.UnmarshalBytes(optVal) if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { return syserr.ErrDomain } @@ -1883,7 +1883,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } var v linux.Linger - v.UnmarshalBytes(optVal[:linux.SizeOfLinger]) + v.UnmarshalBytes(optVal) if v != (linux.Linger{}) { socket.SetSockOptEmitUnimplementedEvent(t, name) @@ -2222,12 +2222,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR if len(optVal) >= inetMulticastRequestWithNICSize { var req linux.InetMulticastRequestWithNIC - req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize]) + req.UnmarshalUnsafe(optVal) return req, nil } var req linux.InetMulticastRequestWithNIC - req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize]) + req.InetMulticastRequest.UnmarshalUnsafe(optVal) return req, nil } @@ -2237,7 +2237,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse } var req linux.Inet6MulticastRequest - req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize]) + req.UnmarshalUnsafe(optVal) return req, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fc5431eb1..01073df72 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -595,19 +595,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { switch family { case unix.AF_INET: var addr linux.SockAddrInet - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_INET6: var addr linux.SockAddrInet6 - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_UNIX: var addr linux.SockAddrUnix - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr case unix.AF_NETLINK: var addr linux.SockAddrNetlink - addr.UnmarshalUnsafe(data[:addr.SizeBytes()]) + addr.UnmarshalUnsafe(data) return &addr default: panic(fmt.Sprintf("Unsupported socket family %v", family)) @@ -738,7 +738,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInetSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInetSize]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -751,7 +751,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrInet6Size { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrInet6Size]) + a.UnmarshalUnsafe(addr) out := tcpip.FullAddress{ Addr: BytesToIPAddress(a.Addr[:]), @@ -767,7 +767,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { if len(addr) < sockAddrLinkSize { return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - a.UnmarshalUnsafe(addr[:sockAddrLinkSize]) + a.UnmarshalUnsafe(addr) // TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have // an ethernet address. if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { |