summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/control/control.go39
-rw-r--r--pkg/sentry/socket/hostinet/socket.go137
-rw-r--r--pkg/sentry/socket/netstack/netstack.go110
-rw-r--r--pkg/sentry/socket/socket.go55
-rw-r--r--pkg/sentry/socket/socket_state_autogen.go3
5 files changed, 275 insertions, 69 deletions
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index b88cdca48..ff6b71802 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -371,6 +371,17 @@ func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, b
buf, level, optType, t.Arch().Width(), originalDstAddress)
}
+// PackSockExtendedErr packs an IP*_RECVERR socket control message.
+func PackSockExtendedErr(t *kernel.Task, sockErr linux.SockErrCMsg, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ sockErr.CMsgLevel(),
+ sockErr.CMsgType(),
+ t.Arch().Width(),
+ sockErr,
+ )
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -403,6 +414,10 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf)
}
+ if cmsgs.IP.SockErr != nil {
+ buf = PackSockExtendedErr(t, cmsgs.IP.SockErr, buf)
+ }
+
return buf
}
@@ -440,6 +455,10 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes())
}
+ if cmsgs.IP.SockErr != nil {
+ space += cmsgSpace(t, cmsgs.IP.SockErr.SizeBytes())
+ }
+
return space
}
@@ -546,6 +565,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.IP.OriginalDstAddress = &addr
i += binary.AlignUp(length, width)
+ case linux.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
@@ -568,6 +597,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
cmsgs.IP.OriginalDstAddress = &addr
i += binary.AlignUp(length, width)
+ case linux.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ if length < errCmsg.SizeBytes() {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ errCmsg.UnmarshalBytes(buf[i : i+errCmsg.SizeBytes()])
+ cmsgs.IP.SockErr = &errCmsg
+ i += binary.AlignUp(length, width)
+
default:
return socket.ControlMessages{}, syserror.EINVAL
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index be418df2e..5b868216d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -331,12 +331,12 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
@@ -377,14 +377,14 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR, linux.IP_RECVERR:
optlen = sizeofInt32
case linux.IP_PKTINFO:
optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
- case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_RECVERR, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR:
optlen = sizeofInt32
}
case linux.SOL_SOCKET:
@@ -416,68 +416,76 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []
return nil
}
-// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
- // Only allow known and safe flags.
- //
- // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
- // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
- // Socket interface's dependence on netstack.
- if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
- return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
- }
+func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) {
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
- var senderAddr linux.SockAddr
+ msg := syscall.Msghdr{}
+ if len(iovs) > 0 {
+ msg.Iov = &iovs[0]
+ msg.Iovlen = uint64(len(iovs))
+ }
var senderAddrBuf []byte
if senderRequested {
senderAddrBuf = make([]byte, sizeofSockaddr)
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(sizeofSockaddr)
}
-
var controlBuf []byte
- var msgFlags int
-
- recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
- // Refuse to do anything if any part of dst.Addrs was unusable.
- if uint64(dst.NumBytes()) != dsts.NumBytes() {
- return 0, nil
- }
- if dsts.IsEmpty() {
- return 0, nil
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
}
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err
+ }
+ return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err
+}
- // We always do a non-blocking recv*().
- sysflags := flags | syscall.MSG_DONTWAIT
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ // Only allow known and safe flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC|syscall.MSG_ERRQUEUE) != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ }
- iovs := safemem.IovecsFromBlockSeq(dsts)
- msg := syscall.Msghdr{
- Iov: &iovs[0],
- Iovlen: uint64(len(iovs)),
- }
- if len(senderAddrBuf) != 0 {
- msg.Name = &senderAddrBuf[0]
- msg.Namelen = uint32(len(senderAddrBuf))
- }
- if controlLen > 0 {
- if controlLen > maxControlLen {
- controlLen = maxControlLen
+ var senderAddrBuf []byte
+ var controlBuf []byte
+ var msgFlags int
+ copyToDst := func() (int64, error) {
+ var n uint64
+ var err error
+ if dst.NumBytes() == 0 {
+ // We want to make the recvmsg(2) call to the host even if dst is empty
+ // to fetch control messages, sender address or errors if any occur.
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen)
+ return int64(n), err
+ }
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
}
- controlBuf = make([]byte, controlLen)
- msg.Control = &controlBuf[0]
- msg.Controllen = controlLen
- }
- n, err := recvmsg(s.fd, &msg, sysflags)
- if err != nil {
- return 0, err
- }
- senderAddrBuf = senderAddrBuf[:msg.Namelen]
- msgFlags = int(msg.Flags)
- controlLen = uint64(msg.Controllen)
- return n, nil
- })
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen)
+ return n, err
+ })
+ return dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
var ch chan struct{}
- n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
- if flags&syscall.MSG_DONTWAIT == 0 {
+ n, err := copyToDst()
+ // recv*(MSG_ERRQUEUE) never blocks, even without MSG_DONTWAIT.
+ if flags&(syscall.MSG_DONTWAIT|syscall.MSG_ERRQUEUE) == 0 {
for err == syserror.ErrWouldBlock {
// We only expect blocking to come from the actual syscall, in which
// case it can't have returned any data.
@@ -494,22 +502,26 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
s.EventRegister(&e, waiter.EventIn)
defer s.EventUnregister(&e)
}
- n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ n, err = copyToDst()
}
}
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ var senderAddr linux.SockAddr
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf)
if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
}
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil
+}
+func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages {
controlMessages := socket.ControlMessages{}
for _, unixCmsg := range unixControlMessages {
switch unixCmsg.Header.Level {
@@ -536,6 +548,11 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var addr linux.SockAddrInet
binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IP_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv4
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
}
case linux.SOL_IPV6:
@@ -548,6 +565,11 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var addr linux.SockAddrInet6
binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr)
controlMessages.IP.OriginalDstAddress = &addr
+
+ case syscall.IPV6_RECVERR:
+ var errCmsg linux.SockErrCMsgIPv6
+ errCmsg.UnmarshalBytes(unixCmsg.Data)
+ controlMessages.IP.SockErr = &errCmsg
}
case linux.SOL_TCP:
@@ -558,8 +580,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
}
-
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+ return controlMessages
}
// SendMsg implements socket.Socket.SendMsg.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 23d5cab9c..3f587638f 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -1042,10 +1042,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return &v, nil
case linux.SO_BINDTODEVICE:
- var v tcpip.BindToDeviceOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetBindToDevice()
if v == 0 {
var b primitive.ByteSlice
return &b, nil
@@ -1405,6 +1402,13 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
return &v, nil
+ case linux.IPV6_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
case linux.IPV6_RECVORIGDSTADDR:
if outLen < sizeOfInt32 {
@@ -1579,6 +1583,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
return &v, nil
+ case linux.IP_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
+
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1789,8 +1801,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
name := string(optVal[:n])
if name == "" {
- v := tcpip.BindToDeviceOption(0)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0))
}
s := t.NetworkContext()
if s == nil {
@@ -1798,8 +1809,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
for nicID, nic := range s.Interfaces() {
if nic.Name == name {
- v := tcpip.BindToDeviceOption(nicID)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID))
}
}
return syserr.ErrUnknownDevice
@@ -2129,6 +2139,16 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetReceiveTClass(v != 0)
return nil
+ case linux.IPV6_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2317,6 +2337,17 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
ep.SocketOptions().SetReceiveTOS(v != 0)
return nil
+ case linux.IP_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
+
case linux.IP_PKTINFO:
if len(optVal) == 0 {
return nil
@@ -2386,7 +2417,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
linux.IP_RECVTTL,
@@ -2462,7 +2492,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_MULTICAST_IF,
linux.IPV6_MULTICAST_LOOP,
linux.IPV6_RECVDSTOPTS,
- linux.IPV6_RECVERR,
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
@@ -2496,7 +2525,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
linux.IP_PKTINFO,
linux.IP_PKTOPTIONS,
linux.IP_MTU_DISCOVER,
- linux.IP_RECVERR,
linux.IP_RECVTTL,
linux.IP_RECVTOS,
linux.IP_MTU,
@@ -2772,6 +2800,8 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
IP: socket.IPControlMessages{
HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
Timestamp: s.readCM.Timestamp,
+ HasInq: s.readCM.HasInq,
+ Inq: s.readCM.Inq,
HasTOS: s.readCM.HasTOS,
TOS: s.readCM.TOS,
HasTClass: s.readCM.HasTClass,
@@ -2779,6 +2809,7 @@ func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
HasIPPacketInfo: s.readCM.HasIPPacketInfo,
PacketInfo: s.readCM.PacketInfo,
OriginalDstAddress: s.readCM.OriginalDstAddress,
+ SockErr: s.readCM.SockErr,
},
}
}
@@ -2795,9 +2826,66 @@ func (s *socketOpsCommon) updateTimestamp() {
}
}
+// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb().
+func (s *socketOpsCommon) dequeueErr() *tcpip.SockError {
+ so := s.Endpoint.SocketOptions()
+ err := so.DequeueErr()
+ if err == nil {
+ return nil
+ }
+
+ // Update socket error to reflect ICMP errors in queue.
+ if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nextErr.Err)
+ } else if err.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nil)
+ }
+ return err
+}
+
+// addrFamilyFromNetProto returns the address family identifier for the given
+// network protocol.
+func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int {
+ switch net {
+ case header.IPv4ProtocolNumber:
+ return linux.AF_INET
+ case header.IPv6ProtocolNumber:
+ return linux.AF_INET6
+ default:
+ panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net))
+ }
+}
+
+// recvErr handles MSG_ERRQUEUE for recvmsg(2).
+// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error().
+func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ sockErr := s.dequeueErr()
+ if sockErr == nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ // The payload of the original packet that caused the error is passed as
+ // normal data via msg_iovec. -- recvmsg(2)
+ msgFlags := linux.MSG_ERRQUEUE
+ if int(dst.NumBytes()) < len(sockErr.Payload) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ n, err := dst.CopyOut(t, sockErr.Payload)
+
+ // The original destination address of the datagram that caused the error is
+ // supplied via msg_name. -- recvmsg(2)
+ dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst)
+ cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})}
+ return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err)
+}
+
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return s.recvErr(t, dst)
+ }
+
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index bcc426e33..97729dacc 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -56,6 +56,57 @@ func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPack
return p
}
+// errOriginToLinux maps tcpip socket origin to Linux socket origin constants.
+func errOriginToLinux(origin tcpip.SockErrOrigin) uint8 {
+ switch origin {
+ case tcpip.SockExtErrorOriginNone:
+ return linux.SO_EE_ORIGIN_NONE
+ case tcpip.SockExtErrorOriginLocal:
+ return linux.SO_EE_ORIGIN_LOCAL
+ case tcpip.SockExtErrorOriginICMP:
+ return linux.SO_EE_ORIGIN_ICMP
+ case tcpip.SockExtErrorOriginICMP6:
+ return linux.SO_EE_ORIGIN_ICMP6
+ default:
+ panic(fmt.Sprintf("unknown socket origin: %d", origin))
+ }
+}
+
+// sockErrCmsgToLinux converts SockError control message from tcpip format to
+// Linux format.
+func sockErrCmsgToLinux(sockErr *tcpip.SockError) linux.SockErrCMsg {
+ if sockErr == nil {
+ return nil
+ }
+
+ ee := linux.SockExtendedErr{
+ Errno: uint32(syserr.TranslateNetstackError(sockErr.Err).ToLinux().Number()),
+ Origin: errOriginToLinux(sockErr.ErrOrigin),
+ Type: sockErr.ErrType,
+ Code: sockErr.ErrCode,
+ Info: sockErr.ErrInfo,
+ }
+
+ switch sockErr.NetProto {
+ case header.IPv4ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv4{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet)
+ }
+ return errMsg
+ case header.IPv6ProtocolNumber:
+ errMsg := &linux.SockErrCMsgIPv6{SockExtendedErr: ee}
+ if len(sockErr.Offender.Addr) > 0 {
+ addr, _ := ConvertAddress(linux.AF_INET6, sockErr.Offender)
+ errMsg.Offender = *addr.(*linux.SockAddrInet6)
+ }
+ return errMsg
+ default:
+ panic(fmt.Sprintf("invalid net proto for creating SockErrCMsg: %d", sockErr.NetProto))
+ }
+}
+
// NewIPControlMessages converts the tcpip ControlMessgaes (which does not
// have Linux specific format) to Linux format.
func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages {
@@ -75,6 +126,7 @@ func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessa
HasIPPacketInfo: cmgs.HasIPPacketInfo,
PacketInfo: packetInfoToLinux(cmgs.PacketInfo),
OriginalDstAddress: orgDstAddr,
+ SockErr: sockErrCmsgToLinux(cmgs.SockErr),
}
}
@@ -117,6 +169,9 @@ type IPControlMessages struct {
// OriginalDestinationAddress holds the original destination address
// and port of the incoming packet.
OriginalDstAddress linux.SockAddr
+
+ // SockErr is the dequeued socket error on recvmsg(MSG_ERRQUEUE).
+ SockErr linux.SockErrCMsg
}
// Release releases Unix domain socket credentials and rights.
diff --git a/pkg/sentry/socket/socket_state_autogen.go b/pkg/sentry/socket/socket_state_autogen.go
index 4f854f99f..970698808 100644
--- a/pkg/sentry/socket/socket_state_autogen.go
+++ b/pkg/sentry/socket/socket_state_autogen.go
@@ -23,6 +23,7 @@ func (i *IPControlMessages) StateFields() []string {
"HasIPPacketInfo",
"PacketInfo",
"OriginalDstAddress",
+ "SockErr",
}
}
@@ -41,6 +42,7 @@ func (i *IPControlMessages) StateSave(stateSinkObject state.Sink) {
stateSinkObject.Save(8, &i.HasIPPacketInfo)
stateSinkObject.Save(9, &i.PacketInfo)
stateSinkObject.Save(10, &i.OriginalDstAddress)
+ stateSinkObject.Save(11, &i.SockErr)
}
func (i *IPControlMessages) afterLoad() {}
@@ -57,6 +59,7 @@ func (i *IPControlMessages) StateLoad(stateSourceObject state.Source) {
stateSourceObject.Load(8, &i.HasIPPacketInfo)
stateSourceObject.Load(9, &i.PacketInfo)
stateSourceObject.Load(10, &i.OriginalDstAddress)
+ stateSourceObject.Load(11, &i.SockErr)
}
func (to *SendReceiveTimeout) StateTypeName() string {