summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/netstack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/netstack')
-rw-r--r--pkg/sentry/socket/netstack/BUILD2
-rw-r--r--pkg/sentry/socket/netstack/netstack.go224
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go16
3 files changed, 160 insertions, 82 deletions
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ea6ebd0e2..1fb777a6c 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -51,6 +51,8 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
+ "//tools/go_marshal/marshal",
+ "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 49a04e613..9856ab8c5 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -26,6 +26,7 @@ package netstack
import (
"bytes"
+ "fmt"
"io"
"math"
"reflect"
@@ -61,6 +62,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -909,7 +912,7 @@ func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -919,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -955,7 +958,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -970,7 +973,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}
@@ -980,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -1013,7 +1016,7 @@ func boolToInt32(v bool) int32 {
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -1024,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// Get the last error and convert it.
err := ep.GetSockOpt(tcpip.ErrorOption{})
if err == nil {
- return int32(0), nil
+ optP := primitive.Int32(0)
+ return &optP, nil
}
- return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number())
+ return &optP, nil
case linux.SO_PEERCRED:
if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
@@ -1034,11 +1040,12 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
tcred := t.Credentials()
- return syscall.Ucred{
- Pid: int32(t.ThreadGroup().ID()),
- Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
- Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
- }, nil
+ creds := linux.ControlMessageCredentials{
+ PID: int32(t.ThreadGroup().ID()),
+ UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }
+ return &creds, nil
case linux.SO_PASSCRED:
if outLen < sizeOfInt32 {
@@ -1049,7 +1056,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1065,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_RCVBUF:
if outLen < sizeOfInt32 {
@@ -1081,7 +1091,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
size = math.MaxInt32
}
- return int32(size), nil
+ sizeP := primitive.Int32(size)
+ return &sizeP, nil
case linux.SO_REUSEADDR:
if outLen < sizeOfInt32 {
@@ -1092,7 +1103,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
@@ -1103,7 +1115,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1111,7 +1125,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
if v == 0 {
- return []byte{}, nil
+ var b primitive.ByteSlice
+ return &b, nil
}
if outLen < linux.IFNAMSIZ {
return nil, syserr.ErrInvalidArgument
@@ -1126,7 +1141,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
// interface was removed.
return nil, syserr.ErrUnknownDevice
}
- return append([]byte(nic.Name), 0), nil
+
+ name := primitive.ByteSlice(append([]byte(nic.Name), 0))
+ return &name, nil
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
@@ -1137,7 +1154,9 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
@@ -1148,13 +1167,17 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
return nil, syserr.ErrInvalidArgument
}
- return linux.Linger{}, nil
+
+ linger := linux.Linger{}
+ return &linger, nil
case linux.SO_SNDTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1162,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.SendTimeout()), nil
+ sendTimeout := linux.NsecToTimeval(s.SendTimeout())
+ return &sendTimeout, nil
case linux.SO_RCVTIMEO:
// TODO(igudger): Linux allows shorter lengths for partial results.
@@ -1170,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- return linux.NsecToTimeval(s.RecvTimeout()), nil
+ recvTimeout := linux.NsecToTimeval(s.RecvTimeout())
+ return &recvTimeout, nil
case linux.SO_OOBINLINE:
if outLen < sizeOfInt32 {
@@ -1182,7 +1207,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1193,7 +1219,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
@@ -1202,7 +1229,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
-func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.TCP_NODELAY:
if outLen < sizeOfInt32 {
@@ -1213,7 +1240,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(!v), nil
+
+ vP := primitive.Int32(boolToInt32(!v))
+ return &vP, nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
@@ -1224,7 +1253,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
@@ -1235,7 +1266,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
@@ -1246,8 +1279,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
@@ -1258,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveIdle, nil
case linux.TCP_KEEPINTVL:
if outLen < sizeOfInt32 {
@@ -1270,8 +1303,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Second), nil
+ keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second)
+ return &keepAliveInterval, nil
case linux.TCP_KEEPCNT:
if outLen < sizeOfInt32 {
@@ -1282,8 +1315,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_USER_TIMEOUT:
if outLen < sizeOfInt32 {
@@ -1294,8 +1327,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(time.Duration(v) / time.Millisecond), nil
+ tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond)
+ return &tcpUserTimeout, nil
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
@@ -1308,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
info := linux.TCPInfo{}
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &info)
- if len(ib) > outLen {
- ib = ib[:outLen]
+ buf := t.CopyScratchBuffer(info.SizeBytes())
+ info.MarshalUnsafe(buf)
+ if len(buf) > outLen {
+ buf = buf[:outLen]
}
-
- return ib, nil
+ bufP := primitive.ByteSlice(buf)
+ return &bufP, nil
case linux.TCP_CC_INFO,
linux.TCP_NOTSENT_LOWAT,
@@ -1343,7 +1377,9 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
b := make([]byte, toCopy)
copy(b, v)
- return b, nil
+
+ bP := primitive.ByteSlice(b)
+ return &bP, nil
case linux.TCP_LINGER2:
if outLen < sizeOfInt32 {
@@ -1355,7 +1391,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
if outLen < sizeOfInt32 {
@@ -1367,7 +1404,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.TranslateNetstackError(err)
}
- return int32(time.Duration(v) / time.Second), nil
+ tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second)
+ return &tcpDeferAccept, nil
case linux.TCP_SYNCNT:
if outLen < sizeOfInt32 {
@@ -1378,8 +1416,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.TCP_WINDOW_CLAMP:
if outLen < sizeOfInt32 {
@@ -1390,8 +1428,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1399,7 +1437,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1410,7 +1448,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1418,21 +1458,24 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_TCLASS:
// Length handling for parity with Linux.
if outLen == 0 {
- return make([]byte, 0), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- uintv := uint32(v)
+ uintv := primitive.Uint32(v)
// Linux truncates the output binary to outLen.
- ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ ib := t.CopyScratchBuffer(uintv.SizeBytes())
+ uintv.MarshalUnsafe(ib)
// Handle cases where outLen is lesser than sizeOfInt32.
if len(ib) > outLen {
ib = ib[:outLen]
}
- return ib, nil
+ ibP := primitive.ByteSlice(ib)
+ return &ibP, nil
case linux.IPV6_RECVTCLASS:
if outLen < sizeOfInt32 {
@@ -1443,7 +1486,9 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1452,7 +1497,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1465,11 +1510,12 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
}
// Fill in the default value, if needed.
- if v == 0 {
- v = DefaultTTL
+ vP := primitive.Int32(v)
+ if vP == 0 {
+ vP = DefaultTTL
}
- return int32(v), nil
+ return &vP, nil
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
@@ -1481,7 +1527,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.TranslateNetstackError(err)
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_MULTICAST_IF:
if outLen < len(linux.InetAddr{}) {
@@ -1495,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
- return a.(*linux.SockAddrInet).Addr, nil
+ return &a.(*linux.SockAddrInet).Addr, nil
case linux.IP_MULTICAST_LOOP:
if outLen < sizeOfInt32 {
@@ -1506,21 +1553,26 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
- return []byte(nil), nil
+ var b primitive.ByteSlice
+ return &b, nil
}
v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
- return uint8(v), nil
+ vP := primitive.Uint8(v)
+ return &vP, nil
}
- return int32(v), nil
+ vP := primitive.Int32(v)
+ return &vP, nil
case linux.IP_RECVTOS:
if outLen < sizeOfInt32 {
@@ -1531,7 +1583,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1542,7 +1596,9 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- return boolToInt32(v), nil
+
+ vP := primitive.Int32(boolToInt32(v))
+ return &vP, nil
default:
emitUnimplementedEventIP(t, name)
@@ -2468,6 +2524,23 @@ func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
}
+func toLinuxPacketType(pktType tcpip.PacketType) uint8 {
+ switch pktType {
+ case tcpip.PacketHost:
+ return linux.PACKET_HOST
+ case tcpip.PacketOtherHost:
+ return linux.PACKET_OTHERHOST
+ case tcpip.PacketOutgoing:
+ return linux.PACKET_OUTGOING
+ case tcpip.PacketBroadcast:
+ return linux.PACKET_BROADCAST
+ case tcpip.PacketMulticast:
+ return linux.PACKET_MULTICAST
+ default:
+ panic(fmt.Sprintf("unknown packet type: %d", pktType))
+ }
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -2526,6 +2599,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
switch v := addr.(type) {
case *linux.SockAddrLink:
v.Protocol = htons(uint16(s.linkPacketInfo.Protocol))
+ v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType)
}
}
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index d65a89316..a9025b0ec 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -31,6 +31,8 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
+ "gvisor.dev/gvisor/tools/go_marshal/marshal"
+ "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -200,7 +202,7 @@ func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
// implemented specifically for netstack.SocketVFS2 rather than
// commonEndpoint. commonEndpoint should be extended to support socket
@@ -210,25 +212,25 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptTimestamp {
val = 1
}
- return val, nil
+ return &val, nil
}
if level == linux.SOL_TCP && name == linux.TCP_INQ {
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- val := int32(0)
+ val := primitive.Int32(0)
s.readMu.Lock()
defer s.readMu.Unlock()
if s.sockOptInq {
val = 1
}
- return val, nil
+ return &val, nil
}
if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
@@ -246,7 +248,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return info, nil
+ return &info, nil
case linux.IPT_SO_GET_ENTRIES:
if outLen < linux.SizeOfIPTGetEntries {
@@ -261,7 +263,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
if err != nil {
return nil, err
}
- return entries, nil
+ return &entries, nil
}
}