summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go117
-rw-r--r--pkg/sentry/socket/hostinet/socket.go179
-rw-r--r--pkg/sentry/socket/netstack/netstack.go209
-rw-r--r--pkg/sentry/socket/socket.go129
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go22
6 files changed, 501 insertions, 156 deletions
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index ca16d0381..fb7c5dc61 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -23,7 +23,6 @@ go_library(
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/vfs",
"//pkg/syserror",
- "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 70ccf77a7..ff6b71802 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,7 +26,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -344,18 +343,42 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
}
// PackIPPacketInfo packs an IP_PKTINFO socket control message.
-func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
- var p linux.ControlMessageIPPacketInfo
- p.NIC = int32(packetInfo.NIC)
- copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
- copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
-
+func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IP,
linux.IP_PKTINFO,
t.Arch().Width(),
- p,
+ packetInfo,
+ )
+}
+
+// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message.
+func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte {
+ var level uint32
+ var optType uint32
+ switch originalDstAddress.(type) {
+ case *linux.SockAddrInet:
+ level = linux.SOL_IP
+ optType = linux.IP_RECVORIGDSTADDR
+ case *linux.SockAddrInet6:
+ level = linux.SOL_IPV6
+ optType = linux.IPV6_RECVORIGDSTADDR
+ default:
+ panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg")
+ }
+ return putCmsgStruct(
+ buf, level, optType, t.Arch().Width(), originalDstAddress)
+}
+
+// PackSockExtendedErr packs an IP*_RECVERR socket control message.
+func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ sockErr.CMsgLevel(),
+ sockErr.CMsgType(),
+ t.Arch().Width(),
+ sockErr,
)
}
@@ -384,7 +407,15 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
}
if cmsgs.IP.HasIPPacketInfo {
- buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf)
+ }
+
+ if cmsgs.IP.OriginalDstAddress != nil {
+ buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
+ }
+
+ if cmsgs.IP.SockErr != nil {
+ buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf)
}
return buf
@@ -416,17 +447,19 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
}
- return space
-}
+ if cmsgs.IP.HasIPPacketInfo {
+ space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo)
+ }
-// NewIPPacketInfo returns the IPPacketInfo struct.
-func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
- var p tcpip.IPPacketInfo
- p.NIC = tcpip.NICID(packetInfo.NIC)
- copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
- copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+ if cmsgs.IP.OriginalDstAddress != nil {
+ space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
+ }
- return p
+ if cmsgs.IP.SockErr != nil {
+ space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes())
+ }
+
+ return space
}
// Parse parses a raw socket control message into portable objects.
@@ -489,6 +522,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.Unix.Credentials = scmCreds
i += binary.AlignUp(length, width)
+ case linux.SO_TIMESTAMP:
+ if length < linux.SizeOfTimeval {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTimestamp = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp)
+ i += binary.AlignUp(length, width)
+
default:
// Unknown message type.
return socket.ControlMessages{}, syserror.EINVAL
@@ -512,7 +553,26 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ cmsgs.IP.PacketInfo = packetInfo
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
i += binary.AlignUp(length, width)
default:
@@ -528,6 +588,25 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
i += binary.AlignUp(length, width)
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ if length < addr.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr)
+ cmsgs.IP.OriginalDstAddress = &addr
+ i += binary.AlignUp(length, width)
+
+ case linux.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 7d3c4a01c..5b868216d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
case linux.SO_LINGER:
optlen = syscall.SizeofLinger
@@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
case linux.IP_PKTINFO:
optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
switch name {
- case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP:
optlen = sizeofInt32
}
case linux.SOL_TCP:
switch name {
- case linux.TCP_NODELAY:
+ case linux.TCP_NODELAY, linux.TCP_INQ:
optlen = sizeofInt32
}
}
@@ -416,68 +416,76 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
return nil
}
-// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- // Only allow known and safe flags.
- //
- // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
- // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
- // Socket interface's dependence on netstack.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
- }
+func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
- var senderAddr linux.SockAddr
+ msg := syscall.Msghdr{}
+ if len(iovs) > 0 {
+ msg.Iov = &iovs[0]
+ msg.Iovlen = uint64(len(iovs))
+ }
var senderAddrBuf []byte
if senderRequested {
senderAddrBuf = make([]byte, sizeofSockaddr)
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(sizeofSockaddr)
}
-
var controlBuf []byte
- var msgFlags int
-
- recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
- // Refuse to do anything if any part of dst.Addrs was unusable.
- if uint64(dst.NumBytes()) != dsts.NumBytes() {
- return 0, nil
- }
- if dsts.IsEmpty() {
- return 0, nil
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
}
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
+ }
+ return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err
+}
- // We always do a non-blocking recv*().
- sysflags := flags | syscall.MSG_DONTWAIT
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ // Only allow known and safe flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ }
- iovs := safemem.IovecsFromBlockSeq(dsts)
- msg := syscall.Msghdr{
- Iov: &iovs[0],
- Iovlen: uint64(len(iovs)),
- }
- if len(senderAddrBuf) != 0 {
- msg.Name = &senderAddrBuf[0]
- msg.Namelen = uint32(len(senderAddrBuf))
- }
- if controlLen > 0 {
- if controlLen > maxControlLen {
- controlLen = maxControlLen
+ var senderAddrBuf []byte
+ var controlBuf []byte
+ var msgFlags int
+ copyToDst := func() (int64, error) {
+ var n uint64
+ var err error
+ if dst.NumBytes() == 0 {
+ // We want to make the recvmsg(2) call to the host even if dst is empty
+ // to fetch control messages, sender address or errors if any occur.
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
+ return int64(n), err
+ }
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
}
- controlBuf = make([]byte, controlLen)
- msg.Control = &controlBuf[0]
- msg.Controllen = controlLen
- }
- n, err := recvmsg(s.fd, &msg, sysflags)
- if err != nil {
- return 0, err
- }
- senderAddrBuf = senderAddrBuf[:msg.Namelen]
- msgFlags = int(msg.Flags)
- controlLen = uint64(msg.Controllen)
- return n, nil
- })
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
+ return n, err
+ })
+ return dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
var ch chan struct{}
- n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
- if flags&syscall.MSG_DONTWAIT == 0 {
+ n, err := copyToDst()
+ // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
+ if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 {
for err == syserror.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
// case it can't have returned any data.
@@ -494,48 +502,85 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
s.EventRegister(&e, waiter.EventIn)
defer s.EventUnregister(&e)
}
- n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err = copyToDst()
}
}
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ var senderAddr linux.SockAddr
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
+}
+func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
controlMessages := socket.ControlMessages{}
for _, unixCmsg := range unixControlMessages {
switch unixCmsg.Header.Level {
- case syscall.SOL_IP:
+ case linux.SOL_SOCKET:
switch unixCmsg.Header.Type {
- case syscall.IP_TOS:
+ case linux.SO_TIMESTAMP:
+ controlMessages.IP.HasTimestamp = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp)
+ }
+
+ case linux.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case linux.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
- case syscall.IP_PKTINFO:
+ case linux.IP_PKTINFO:
controlMessages.IP.HasIPPacketInfo = true
var packetInfo linux.ControlMessageIPPacketInfo
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
- controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ controlMessages.IP.PacketInfo = packetInfo
+
+ case linux.IP_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
}
- case syscall.SOL_IPV6:
+ case linux.SOL_IPV6:
switch unixCmsg.Header.Type {
- case syscall.IPV6_TCLASS:
+ case linux.IPV6_TCLASS:
controlMessages.IP.HasTClass = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
+ controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
+ }
+
+ case linux.SOL_TCP:
+ switch unixCmsg.Header.Type {
+ case linux.TCP_INQ:
+ controlMessages.IP.HasInq = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq)
}
}
}
-
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+ return controlMessages
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 3cc0d4f0f..3f587638f 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -320,7 +320,7 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -408,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = v
- s.readCM = cms
+ s.readCM = socket.NewIPControlMessages(s.family, cms)
atomic.StoreUint32(&s.readViewHasData, 1)
return nil
@@ -428,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
return
}
- var v tcpip.LingerOption
- if err := s.Endpoint.GetSockOpt(&v); err != nil {
- return
- }
-
+ v := s.Endpoint.SocketOptions().GetLinger()
// The case for zero timeout is handled in tcp endpoint close function.
// Close is blocked until either:
// 1. The endpoint state is not in any of the states: FIN-WAIT1,
@@ -965,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.LastError()
+ err := ep.SocketOptions().GetLastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1046,10 +1042,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return &v, nil
case linux.SO_BINDTODEVICE:
- var v tcpip.BindToDeviceOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetBindToDevice()
if v == 0 {
var b primitive.ByteSlice
return &b, nil
@@ -1092,11 +1085,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.LingerOption
var linger linux.Linger
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetLinger()
if v.Enabled {
linger.OnOff = 1
@@ -1127,13 +1117,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.OutOfBandInlineOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(v)
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline()))
+ return &v, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1417,6 +1402,21 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
return &v, nil
+ case linux.IPV6_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1583,6 +1583,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
return &v, nil
+ case linux.IP_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
+
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
return &v, nil
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
+
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
return nil, syserr.ErrInvalidArgument
@@ -1785,8 +1801,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
name := string(optVal[:n])
if name == "" {
- v := tcpip.BindToDeviceOption(0)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0))
}
s := t.NetworkContext()
if s == nil {
@@ -1794,8 +1809,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
for nicID, nic := range s.Interfaces() {
if nic.Name == name {
- v := tcpip.BindToDeviceOption(nicID)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID))
}
}
return syserr.ErrUnknownDevice
@@ -1864,8 +1878,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- opt := tcpip.OutOfBandInlineOption(v)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
+ ep.SocketOptions().SetOutOfBandInline(v != 0)
+ return nil
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1888,10 +1902,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(
- ep.SetSockOpt(&tcpip.LingerOption{
- Enabled: v.OnOff != 0,
- Timeout: time.Second * time.Duration(v.Linger)}))
+ ep.SocketOptions().SetLinger(tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger),
+ })
+ return nil
case linux.SO_DETACH_FILTER:
// optval is ignored.
@@ -2094,6 +2109,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2115,6 +2139,16 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetReceiveTClass(v != 0)
return nil
+ case linux.IPV6_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2303,6 +2337,17 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
ep.SocketOptions().SetReceiveTOS(v != 0)
return nil
+ case linux.IP_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
+
case linux.IP_PKTINFO:
if len(optVal) == 0 {
return nil
@@ -2325,6 +2370,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
ep.SocketOptions().SetHeaderIncluded(v != 0)
return nil
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
return syserr.ErrInvalidArgument
@@ -2360,10 +2417,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2437,11 +2492,9 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_MULTICAST_IF,
linux.IPV6_MULTICAST_LOOP,
linux.IPV6_RECVDSTOPTS,
- linux.IPV6_RECVERR,
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2472,7 +2525,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
linux.IP_PKTINFO,
linux.IP_PKTOPTIONS,
linux.IP_MTU_DISCOVER,
- linux.IP_RECVERR,
linux.IP_RECVTTL,
linux.IP_RECVTOS,
linux.IP_MTU,
@@ -2701,7 +2753,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// We need to peek beyond the first message.
dst = dst.DropFirst(n)
num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
+ n, err := s.Endpoint.Peek(dsts)
// TODO(b/78348848): Handle peek timestamp.
if err != nil {
return int64(n), syserr.TranslateNetstackError(err).ToError()
@@ -2745,15 +2797,19 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasInq: s.readCM.HasInq,
+ Inq: s.readCM.Inq,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
+ SockErr: s.readCM.SockErr,
},
}
}
@@ -2770,9 +2826,66 @@ func (s *socketOpsCommon) updateTimestamp() {
}
}
+// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb().
+func (s *socketOpsCommon) dequeueErr() *tcpip.SockError {
+ so := s.Endpoint.SocketOptions()
+ err := so.DequeueErr()
+ if err == nil {
+ return nil
+ }
+
+ // Update socket error to reflect ICMP errors in queue.
+ if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nextErr.Err)
+ } else if err.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nil)
+ }
+ return err
+}
+
+// addrFamilyFromNetProto returns the address family identifier for the given
+// network protocol.
+func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int {
+ switch net {
+ case header.IPv4ProtocolNumber:
+ return linux.AF_INET
+ case header.IPv6ProtocolNumber:
+ return linux.AF_INET6
+ default:
+ panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net))
+ }
+}
+
+// recvErr handles MSG_ERRQUEUE for recvmsg(2).
+// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error().
+func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ sockErr := s.dequeueErr()
+ if sockErr == nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ // The payload of the original packet that caused the error is passed as
+ // normal data via msg_iovec. -- recvmsg(2)
+ msgFlags := linux.MSG_ERRQUEUE
+ if int(dst.NumBytes()) < len(sockErr.Payload) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ n, err := dst.CopyOut(t, sockErr.Payload)
+
+ // The original destination address of the datagram that caused the error is
+ // supplied via msg_name. -- recvmsg(2)
+ dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst)
+ cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})}
+ return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err)
+}
+
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return s.recvErr(t, dst)
+ }
+
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 9049e8a21..97729dacc 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -44,7 +44,134 @@ import (
// control messages.
type ControlMessages struct {
Unix transport.ControlMessages
- IP tcpip.ControlMessages
+ IP IPControlMessages
+}
+
+// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format.
+func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+ return p
+}
+
+// errOriginToLinux maps tcpip socket origin to Linux socket origin constants.
+func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 {
+ switch origin {
+ case tcpip.SockExtErrorOriginNone:
+ return linux.SO_EE_ORIGIN_NONE
+ case tcpip.SockExtErrorOriginLocal:
+ return linux.SO_EE_ORIGIN_LOCAL
+ case tcpip.SockExtErrorOriginICMP:
+ return linux.SO_EE_ORIGIN_ICMP
+ case tcpip.SockExtErrorOriginICMP6:
+ return linux.SO_EE_ORIGIN_ICMP6
+ default:
+ panic(fmt.Sprintf("unknown socket origin: %d", origin))
+ }
+}
+
+// sockErrCmsgToLinux converts SockError control message from tcpip format to
+// Linux format.
+func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg {
+ if sockErr == nil {
+ return nil
+ }
+
+ ee := linux.SockExtendedErr{
+ Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()),
+ Origin: errOriginToLinux(sockErr.ErrOrigin),
+ Type: sockErr.ErrType,
+ Code: sockErr.ErrCode,
+ Info: sockErr.ErrInfo,
+ }
+
+ switch sockErr.NetProto {
+ case header.IPv4ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet)
+ }
+ return errMsg
+ case header.IPv6ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet6)
+ }
+ return errMsg
+ default:
+ panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto))
+ }
+}
+
+// NewIPControlMessages converts the tcpip ControlMessgaes (which does not
+// have Linux specific format) to Linux format.
+func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
+ var orgDstAddr linux.SockAddr
+ if cmgs.HasOriginalDstAddress {
+ orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress)
+ }
+ return IPControlMessages{
+ HasTimestamp: cmgs.HasTimestamp,
+ Timestamp: cmgs.Timestamp,
+ HasInq: cmgs.HasInq,
+ Inq: cmgs.Inq,
+ HasTOS: cmgs.HasTOS,
+ TOS: cmgs.TOS,
+ HasTClass: cmgs.HasTClass,
+ TClass: cmgs.TClass,
+ HasIPPacketInfo: cmgs.HasIPPacketInfo,
+ PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
+ OriginalDstAddress: orgDstAddr,
+ SockErr: sockErrCmsgToLinux(cmgs.SockErr),
+ }
+}
+
+// IPControlMessages contains socket control messages for IP sockets.
+// This can contain Linux specific structures unlike tcpip.ControlMessages.
+//
+// +stateify savable
+type IPControlMessages struct {
+ // HasTimestamp indicates whether Timestamp is valid/set.
+ HasTimestamp bool
+
+ // Timestamp is the time (in ns) that the last packet used to create
+ // the read data was received.
+ Timestamp int64
+
+ // HasInq indicates whether Inq is valid/set.
+ HasInq bool
+
+ // Inq is the number of bytes ready to be received.
+ Inq int32
+
+ // HasTOS indicates whether Tos is valid/set.
+ HasTOS bool
+
+ // TOS is the IPv4 type of service of the associated packet.
+ TOS uint8
+
+ // HasTClass indicates whether TClass is valid/set.
+ HasTClass bool
+
+ // TClass is the IPv6 traffic class of the associated packet.
+ TClass uint32
+
+ // HasIPPacketInfo indicates whether PacketInfo is set.
+ HasIPPacketInfo bool
+
+ // PacketInfo holds interface and address data on an incoming packet.
+ PacketInfo linux.ControlMessageIPPacketInfo
+
+ // OriginalDestinationAddress holds the original destination address
+ // and port of the incoming packet.
+ OriginalDstAddress linux.SockAddr
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr linux.SockErrCMsg
}
// Release releases Unix domain socket credentials and rights.
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 0247e93fa..099a56281 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -746,9 +746,6 @@ type baseEndpoint struct {
// or may be used if the endpoint is connected.
path string
- // linger is used for SO_LINGER socket option.
- linger tcpip.LingerOption
-
// ops is used to get socket level options.
ops tcpip.SocketOptions
}
@@ -840,12 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option.
func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
- switch v := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- e.linger = *v
- e.Unlock()
- }
return nil
}
@@ -922,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
- switch o := opt.(type) {
- case *tcpip.LingerOption:
- e.Lock()
- *o = e.linger
- e.Unlock()
- return nil
-
- default:
- log.Warningf("Unsupported socket option: %T", opt)
- return tcpip.ErrUnknownProtocolOption
- }
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
}
// LastError implements Endpoint.LastError.