summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
authorAyush Ranjan <ayushranjan@google.com>2021-11-05 10:40:53 -0700
committergVisor bot <gvisor-bot@google.com>2021-11-05 10:43:49 -0700
commitce4f4283badb6b07baf9f8e6d99e7a5fd15c92db (patch)
tree848dc50da62da59dc4a5781f9eb7461c58b71512 /pkg/sentry/socket
parent822a647018adbd994114cb0dc8932f2853b805aa (diff)
Make {Un}Marshal{Bytes/Unsafe} return remaining buffer.
Change marshal.Marshallable method signatures to return the remaining buffer. This makes it easier to implement these method manually. Without this, we would have to manually do buffer shifting which is error prone. tools/go_marshal/test:benchmark test does not show change in performance. Additionally fixed some marshalling bugs in fsimpl/fuse. Updated multiple callpoints to get rid of redundant slice indexing work and simplified code using this new signature. Updates #6450 PiperOrigin-RevId: 407857019
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/control.go63
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go5
-rw-r--r--pkg/sentry/socket/hostinet/socket.go14
-rw-r--r--pkg/sentry/socket/hostinet/stack.go6
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go5
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go5
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go5
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go7
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/targets.go15
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go2
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go2
-rw-r--r--pkg/sentry/socket/netlink/message_test.go45
-rw-r--r--pkg/sentry/socket/netlink/socket.go2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go14
-rw-r--r--pkg/sentry/socket/socket.go14
16 files changed, 82 insertions, 124 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 6077b2150..4b036b323 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -17,6 +17,8 @@
package control
import (
+ "time"
+
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/bits"
"gvisor.dev/gvisor/pkg/context"
@@ -29,7 +31,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
- "time"
)
// SCMCredentials represents a SCM_CREDENTIALS socket control message.
@@ -63,10 +64,10 @@ type RightsFiles []*fs.File
// NewSCMRights creates a new SCM_RIGHTS socket control message representation
// using local sentry FDs.
-func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
+func NewSCMRights(t *kernel.Task, fds []primitive.Int32) (SCMRights, error) {
files := make(RightsFiles, 0, len(fds))
for _, fd := range fds {
- file := t.GetFile(fd)
+ file := t.GetFile(int32(fd))
if file == nil {
files.Release(t)
return nil, linuxerr.EBADF
@@ -486,26 +487,25 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint) (socket.ControlMessages, error) {
var (
cmsgs socket.ControlMessages
- fds linux.ControlMessageRights
+ fds []primitive.Int32
)
- for i := 0; i < len(buf); {
- if i+linux.SizeOfControlMessageHeader > len(buf) {
+ for len(buf) > 0 {
+ if linux.SizeOfControlMessageHeader > len(buf) {
return cmsgs, linuxerr.EINVAL
}
var h linux.ControlMessageHeader
- h.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageHeader])
+ buf = h.UnmarshalUnsafe(buf)
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
return socket.ControlMessages{}, linuxerr.EINVAL
}
- if h.Length > uint64(len(buf)-i) {
- return socket.ControlMessages{}, linuxerr.EINVAL
- }
- i += linux.SizeOfControlMessageHeader
length := int(h.Length) - linux.SizeOfControlMessageHeader
+ if length > len(buf) {
+ return socket.ControlMessages{}, linuxerr.EINVAL
+ }
switch h.Level {
case linux.SOL_SOCKET:
@@ -518,11 +518,9 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, linuxerr.EINVAL
}
- for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
- fds = append(fds, int32(hostarch.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
- }
-
- i += bits.AlignUp(length, width)
+ curFDs := make([]primitive.Int32, numRights)
+ primitive.UnmarshalUnsafeInt32Slice(curFDs, buf[:rightsSize])
+ fds = append(fds, curFDs...)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -530,23 +528,21 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
var creds linux.ControlMessageCredentials
- creds.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageCredentials])
+ creds.UnmarshalUnsafe(buf)
scmCreds, err := NewSCMCredentials(t, creds)
if err != nil {
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += bits.AlignUp(length, width)
case linux.SO_TIMESTAMP:
if length < linux.SizeOfTimeval {
return socket.ControlMessages{}, linuxerr.EINVAL
}
var ts linux.Timeval
- ts.UnmarshalUnsafe(buf[i : i+linux.SizeOfTimeval])
+ ts.UnmarshalUnsafe(buf)
cmsgs.IP.Timestamp = ts.ToTime()
cmsgs.IP.HasTimestamp = true
- i += bits.AlignUp(length, width)
default:
// Unknown message type.
@@ -560,9 +556,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
cmsgs.IP.HasTOS = true
var tos primitive.Uint8
- tos.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTOS])
+ tos.UnmarshalUnsafe(buf)
cmsgs.IP.TOS = uint8(tos)
- i += bits.AlignUp(length, width)
case linux.IP_PKTINFO:
if length < linux.SizeOfControlMessageIPPacketInfo {
@@ -571,19 +566,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
cmsgs.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- packetInfo.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageIPPacketInfo])
-
+ packetInfo.UnmarshalUnsafe(buf)
cmsgs.IP.PacketInfo = packetInfo
- i += bits.AlignUp(length, width)
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
if length < addr.SizeBytes() {
return socket.ControlMessages{}, linuxerr.EINVAL
}
- addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
+ addr.UnmarshalUnsafe(buf)
cmsgs.IP.OriginalDstAddress = &addr
- i += bits.AlignUp(length, width)
case linux.IP_RECVERR:
var errCmsg linux.SockErrCMsgIPv4
@@ -591,9 +583,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, linuxerr.EINVAL
}
- errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ errCmsg.UnmarshalBytes(buf)
cmsgs.IP.SockErr = &errCmsg
- i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, linuxerr.EINVAL
@@ -606,18 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
}
cmsgs.IP.HasTClass = true
var tclass primitive.Uint32
- tclass.UnmarshalUnsafe(buf[i : i+linux.SizeOfControlMessageTClass])
+ tclass.UnmarshalUnsafe(buf)
cmsgs.IP.TClass = uint32(tclass)
- i += bits.AlignUp(length, width)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
if length < addr.SizeBytes() {
return socket.ControlMessages{}, linuxerr.EINVAL
}
- addr.UnmarshalUnsafe(buf[i : i+addr.SizeBytes()])
+ addr.UnmarshalUnsafe(buf)
cmsgs.IP.OriginalDstAddress = &addr
- i += bits.AlignUp(length, width)
case linux.IPV6_RECVERR:
var errCmsg linux.SockErrCMsgIPv6
@@ -625,9 +614,8 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
return socket.ControlMessages{}, linuxerr.EINVAL
}
- errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ errCmsg.UnmarshalBytes(buf)
cmsgs.IP.SockErr = &errCmsg
- i += bits.AlignUp(length, width)
default:
return socket.ControlMessages{}, linuxerr.EINVAL
@@ -635,6 +623,11 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte, width uint)
default:
return socket.ControlMessages{}, linuxerr.EINVAL
}
+ if shift := bits.AlignUp(length, width); shift > len(buf) {
+ buf = buf[:0]
+ } else {
+ buf = buf[shift:]
+ }
}
if cmsgs.Unix.Credentials == nil {
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
index 0a989cbeb..a638cb955 100644
--- a/pkg/sentry/socket/control/control_vfs2.go
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -45,10 +46,10 @@ type RightsFilesVFS2 []*vfs.FileDescription
// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
// representation using local sentry FDs.
-func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) {
+func NewSCMRightsVFS2(t *kernel.Task, fds []primitive.Int32) (SCMRightsVFS2, error) {
files := make(RightsFilesVFS2, 0, len(fds))
for _, fd := range fds {
- file := t.GetFileVFS2(fd)
+ file := t.GetFileVFS2(int32(fd))
if file == nil {
files.Release(t)
return nil, linuxerr.EBADF
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 6e2318f75..a31f3ebec 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -578,7 +578,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.SO_TIMESTAMP:
controlMessages.IP.HasTimestamp = true
ts := linux.Timeval{}
- ts.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfTimeval])
+ ts.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.Timestamp = ts.ToTime()
}
@@ -587,18 +587,18 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.IP_TOS:
controlMessages.IP.HasTOS = true
var tos primitive.Uint8
- tos.UnmarshalUnsafe(unixCmsg.Data[:tos.SizeBytes()])
+ tos.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.TOS = uint8(tos)
case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
- packetInfo.UnmarshalUnsafe(unixCmsg.Data[:packetInfo.SizeBytes()])
+ packetInfo.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.PacketInfo = packetInfo
case linux.IP_RECVORIGDSTADDR:
var addr linux.SockAddrInet
- addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.OriginalDstAddress = &addr
case unix.IP_RECVERR:
@@ -612,12 +612,12 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
var tclass primitive.Uint32
- tclass.UnmarshalUnsafe(unixCmsg.Data[:tclass.SizeBytes()])
+ tclass.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.TClass = uint32(tclass)
case linux.IPV6_RECVORIGDSTADDR:
var addr linux.SockAddrInet6
- addr.UnmarshalUnsafe(unixCmsg.Data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.OriginalDstAddress = &addr
case unix.IPV6_RECVERR:
@@ -631,7 +631,7 @@ func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) s
case linux.TCP_INQ:
controlMessages.IP.HasInq = true
var inq primitive.Int32
- inq.UnmarshalUnsafe(unixCmsg.Data[:linux.SizeOfControlMessageInq])
+ inq.UnmarshalUnsafe(unixCmsg.Data)
controlMessages.IP.Inq = int32(inq)
}
}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 61111ac6c..c84ab3fb7 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -138,7 +138,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), unix.SizeofIfInfomsg)
}
var ifinfo linux.InterfaceInfoMessage
- ifinfo.UnmarshalUnsafe(link.Data[:ifinfo.SizeBytes()])
+ ifinfo.UnmarshalUnsafe(link.Data)
inetIF := inet.Interface{
DeviceType: ifinfo.Type,
Flags: ifinfo.Flags,
@@ -169,7 +169,7 @@ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.Netli
return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), unix.SizeofIfAddrmsg)
}
var ifaddr linux.InterfaceAddrMessage
- ifaddr.UnmarshalUnsafe(addr.Data[:ifaddr.SizeBytes()])
+ ifaddr.UnmarshalUnsafe(addr.Data)
inetAddr := inet.InterfaceAddr{
Family: ifaddr.Family,
PrefixLen: ifaddr.PrefixLen,
@@ -201,7 +201,7 @@ func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error)
}
var ifRoute linux.RouteMessage
- ifRoute.UnmarshalUnsafe(routeMsg.Data[:ifRoute.SizeBytes()])
+ ifRoute.UnmarshalUnsafe(routeMsg.Data)
inetRoute := inet.Route{
Family: ifRoute.Family,
DstLen: ifRoute.DstLen,
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 3f1b4a17b..7606d2bbb 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -80,9 +80,8 @@ func marshalEntryMatch(name string, data []byte) []byte {
copy(matcher.Name[:], name)
buf := make([]byte, size)
- entryLen := matcher.XTEntryMatch.SizeBytes()
- matcher.XTEntryMatch.MarshalUnsafe(buf[:entryLen])
- copy(buf[entryLen:], matcher.Data)
+ bufRemain := matcher.XTEntryMatch.MarshalUnsafe(buf)
+ copy(bufRemain, matcher.Data)
return buf
}
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index af31cbc5b..6cbfee8b6 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -141,10 +141,9 @@ func modifyEntries4(task *kernel.Task, stk *stack.Stack, optVal []byte, replace
nflog("optVal has insufficient size for entry %d", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- var entry linux.IPTEntry
- entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[entry.SizeBytes():]
+ var entry linux.IPTEntry
+ optVal = entry.UnmarshalUnsafe(optVal)
if entry.TargetOffset < linux.SizeOfIPTEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
index 6cefe0b9c..902707abf 100644
--- a/pkg/sentry/socket/netfilter/ipv6.go
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -144,10 +144,9 @@ func modifyEntries6(task *kernel.Task, stk *stack.Stack, optVal []byte, replace
nflog("optVal has insufficient size for entry %d", len(optVal))
return nil, syserr.ErrInvalidArgument
}
- var entry linux.IP6TEntry
- entry.UnmarshalUnsafe(optVal[:entry.SizeBytes()])
initialOptValLen := len(optVal)
- optVal = optVal[entry.SizeBytes():]
+ var entry linux.IP6TEntry
+ optVal = entry.UnmarshalUnsafe(optVal)
if entry.TargetOffset < linux.SizeOfIP6TEntry {
nflog("entry has too-small target offset %d", entry.TargetOffset)
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 01f2f8c77..f2cdf091d 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -176,9 +176,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
func SetEntries(task *kernel.Task, stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
- replaceBuf := optVal[:linux.SizeOfIPTReplace]
- optVal = optVal[linux.SizeOfIPTReplace:]
- replace.UnmarshalBytes(replaceBuf)
+ optVal = replace.UnmarshalBytes(optVal)
var table stack.Table
switch replace.Name.String() {
@@ -306,8 +304,7 @@ func parseMatchers(task *kernel.Task, filter stack.IPHeaderFilter, optVal []byte
return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
}
var match linux.XTEntryMatch
- buf := optVal[:match.SizeBytes()]
- match.UnmarshalUnsafe(buf)
+ match.UnmarshalUnsafe(optVal)
nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
// Check some invariants.
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
index 6eff2ae65..d83d7e535 100644
--- a/pkg/sentry/socket/netfilter/owner_matcher.go
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -73,7 +73,7 @@ func (ownerMarshaler) unmarshal(task *kernel.Task, buf []byte, filter stack.IPHe
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.IPTOwnerInfo
- matchData.UnmarshalUnsafe(buf[:linux.SizeOfIPTOwnerInfo])
+ matchData.UnmarshalUnsafe(buf)
nflog("parsed IPTOwnerInfo: %+v", matchData)
var owner OwnerMatcher
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index b9c15daab..eaf601543 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -199,7 +199,7 @@ func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
return nil, syserr.ErrInvalidArgument
}
var standardTarget linux.XTStandardTarget
- standardTarget.UnmarshalUnsafe(buf[:standardTarget.SizeBytes()])
+ standardTarget.UnmarshalUnsafe(buf)
if standardTarget.Verdict < 0 {
// A Verdict < 0 indicates a non-jump verdict.
@@ -253,7 +253,6 @@ func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
return nil, syserr.ErrInvalidArgument
}
var errTgt linux.XTErrorTarget
- buf = buf[:linux.SizeOfXTErrorTarget]
errTgt.UnmarshalUnsafe(buf)
// Error targets are used in 2 cases:
@@ -316,7 +315,6 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (
}
var rt linux.XTRedirectTarget
- buf = buf[:linux.SizeOfXTRedirectTarget]
rt.UnmarshalUnsafe(buf)
// Copy linux.XTRedirectTarget to stack.RedirectTarget.
@@ -405,8 +403,7 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (tar
}
var natRange linux.NFNATRange
- buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
- natRange.UnmarshalUnsafe(buf)
+ natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:])
// We don't support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -477,7 +474,6 @@ func (*snatTargetMakerV4) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
}
var st linux.XTSNATTarget
- buf = buf[:linux.SizeOfXTSNATTarget]
st.UnmarshalUnsafe(buf)
// Copy linux.XTSNATTarget to stack.SNATTarget.
@@ -557,8 +553,7 @@ func (*snatTargetMakerV6) unmarshal(buf []byte, filter stack.IPHeaderFilter) (ta
}
var natRange linux.NFNATRange
- buf = buf[linux.SizeOfXTEntryTarget:nfNATMarshalledSize]
- natRange.UnmarshalUnsafe(buf)
+ natRange.UnmarshalUnsafe(buf[linux.SizeOfXTEntryTarget:])
// TODO(gvisor.dev/issue/5697): Support port or address ranges.
if natRange.MinAddr != natRange.MaxAddr {
@@ -621,7 +616,9 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte, ipv6 bool) (stack.T
return nil, syserr.ErrInvalidArgument
}
var target linux.XTEntryTarget
- target.UnmarshalUnsafe(optVal[:target.SizeBytes()])
+ // Do not advance optVal as targetMake.unmarshal() may unmarshal
+ // XTEntryTarget again but with some added fields.
+ target.UnmarshalUnsafe(optVal)
return unmarshalTarget(target, filter, optVal)
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index e5b73a976..a621a6a16 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -59,7 +59,7 @@ func (tcpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF
// For alignment reasons, the match's total size may
// exceed what's strictly necessary to hold matchData.
var matchData linux.XTTCP
- matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
+ matchData.UnmarshalUnsafe(buf)
nflog("parseMatchers: parsed XTTCP: %+v", matchData)
if matchData.Option != 0 ||
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index aa72ee70c..2ca854764 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -59,7 +59,7 @@ func (udpMarshaler) unmarshal(_ *kernel.Task, buf []byte, filter stack.IPHeaderF
// For alignment reasons, the match's total size may exceed what's
// strictly necessary to hold matchData.
var matchData linux.XTUDP
- matchData.UnmarshalUnsafe(buf[:matchData.SizeBytes()])
+ matchData.UnmarshalUnsafe(buf)
nflog("parseMatchers: parsed XTUDP: %+v", matchData)
if matchData.InverseFlags != 0 {
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
index 968968469..1604b2792 100644
--- a/pkg/sentry/socket/netlink/message_test.go
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -25,33 +25,14 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/netlink"
)
-type dummyNetlinkMsg struct {
- marshal.StubMarshallable
- Foo uint16
-}
-
-func (*dummyNetlinkMsg) SizeBytes() int {
- return 2
-}
-
-func (m *dummyNetlinkMsg) MarshalUnsafe(dst []byte) {
- p := primitive.Uint16(m.Foo)
- p.MarshalUnsafe(dst)
-}
-
-func (m *dummyNetlinkMsg) UnmarshalUnsafe(src []byte) {
- var p primitive.Uint16
- p.UnmarshalUnsafe(src)
- m.Foo = uint16(p)
-}
-
func TestParseMessage(t *testing.T) {
+ dummyNetlinkMsg := primitive.Uint16(0x3130)
tests := []struct {
desc string
input []byte
header linux.NetlinkMessageHeader
- dataMsg *dummyNetlinkMsg
+ dataMsg marshal.Marshallable
restLen int
ok bool
}{
@@ -72,9 +53,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -96,9 +75,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 1,
ok: true,
},
@@ -119,9 +96,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -143,9 +118,7 @@ func TestParseMessage(t *testing.T) {
Seq: 3,
PortID: 4,
},
- dataMsg: &dummyNetlinkMsg{
- Foo: 0x3130,
- },
+ dataMsg: &dummyNetlinkMsg,
restLen: 0,
ok: true,
},
@@ -199,11 +172,11 @@ func TestParseMessage(t *testing.T) {
t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
}
- dataMsg := &dummyNetlinkMsg{}
- _, dataOk := msg.GetData(dataMsg)
+ var dataMsg primitive.Uint16
+ _, dataOk := msg.GetData(&dataMsg)
if !dataOk {
t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
- } else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
+ } else if !reflect.DeepEqual(&dataMsg, test.dataMsg) {
t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 267155807..19c8f340d 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -223,7 +223,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
}
var sa linux.SockAddrNetlink
- sa.UnmarshalUnsafe(b[:sa.SizeBytes()])
+ sa.UnmarshalUnsafe(b)
if sa.Family != linux.AF_NETLINK {
return nil, syserr.ErrInvalidArgument
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index c35cf06f6..e38c4e5da 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -660,7 +660,7 @@ func (s *socketOpsCommon) Bind(_ *kernel.Task, sockaddr []byte) *syserr.Error {
if len(sockaddr) < sockAddrLinkSize {
return syserr.ErrInvalidArgument
}
- a.UnmarshalBytes(sockaddr[:sockAddrLinkSize])
+ a.UnmarshalBytes(sockaddr)
addr = tcpip.FullAddress{
NIC: tcpip.NICID(a.InterfaceIndex),
@@ -1839,7 +1839,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
+ v.UnmarshalBytes(optVal)
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1852,7 +1852,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Timeval
- v.UnmarshalBytes(optVal[:linux.SizeOfTimeval])
+ v.UnmarshalBytes(optVal)
if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
return syserr.ErrDomain
}
@@ -1883,7 +1883,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
var v linux.Linger
- v.UnmarshalBytes(optVal[:linux.SizeOfLinger])
+ v.UnmarshalBytes(optVal)
if v != (linux.Linger{}) {
socket.SetSockOptEmitUnimplementedEvent(t, name)
@@ -2222,12 +2222,12 @@ func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastR
if len(optVal) >= inetMulticastRequestWithNICSize {
var req linux.InetMulticastRequestWithNIC
- req.UnmarshalUnsafe(optVal[:inetMulticastRequestWithNICSize])
+ req.UnmarshalUnsafe(optVal)
return req, nil
}
var req linux.InetMulticastRequestWithNIC
- req.InetMulticastRequest.UnmarshalUnsafe(optVal[:inetMulticastRequestSize])
+ req.InetMulticastRequest.UnmarshalUnsafe(optVal)
return req, nil
}
@@ -2237,7 +2237,7 @@ func copyInMulticastV6Request(optVal []byte) (linux.Inet6MulticastRequest, *syse
}
var req linux.Inet6MulticastRequest
- req.UnmarshalUnsafe(optVal[:inet6MulticastRequestSize])
+ req.UnmarshalUnsafe(optVal)
return req, nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index fc5431eb1..01073df72 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -595,19 +595,19 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
switch family {
case unix.AF_INET:
var addr linux.SockAddrInet
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
case unix.AF_INET6:
var addr linux.SockAddrInet6
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
case unix.AF_UNIX:
var addr linux.SockAddrUnix
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
case unix.AF_NETLINK:
var addr linux.SockAddrNetlink
- addr.UnmarshalUnsafe(data[:addr.SizeBytes()])
+ addr.UnmarshalUnsafe(data)
return &addr
default:
panic(fmt.Sprintf("Unsupported socket family %v", family))
@@ -738,7 +738,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInetSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- a.UnmarshalUnsafe(addr[:sockAddrInetSize])
+ a.UnmarshalUnsafe(addr)
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -751,7 +751,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrInet6Size {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- a.UnmarshalUnsafe(addr[:sockAddrInet6Size])
+ a.UnmarshalUnsafe(addr)
out := tcpip.FullAddress{
Addr: BytesToIPAddress(a.Addr[:]),
@@ -767,7 +767,7 @@ func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
if len(addr) < sockAddrLinkSize {
return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
}
- a.UnmarshalUnsafe(addr[:sockAddrLinkSize])
+ a.UnmarshalUnsafe(addr)
// TODO(https://gvisor.dev/issue/6530): Do not assume all interfaces have
// an ethernet address.
if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {