summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/netstack
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/netstack')
-rw-r--r--pkg/sentry/socket/netstack/netstack.go209
1 files changed, 161 insertions, 48 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 3cc0d4f0f..3f587638f 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -320,7 +320,7 @@ type socketOpsCommon struct {
readView buffer.View
// readCM holds control message information for the last packet read
// from Endpoint.
- readCM tcpip.ControlMessages
+ readCM socket.IPControlMessages
sender tcpip.FullAddress
linkPacketInfo tcpip.LinkPacketInfo
@@ -408,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
s.readView = v
- s.readCM = cms
+ s.readCM = socket.NewIPControlMessages(s.family, cms)
atomic.StoreUint32(&s.readViewHasData, 1)
return nil
@@ -428,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) {
return
}
- var v tcpip.LingerOption
- if err := s.Endpoint.GetSockOpt(&v); err != nil {
- return
- }
-
+ v := s.Endpoint.SocketOptions().GetLinger()
// The case for zero timeout is handled in tcp endpoint close function.
// Close is blocked until either:
// 1. The endpoint state is not in any of the states: FIN-WAIT1,
@@ -965,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.LastError()
+ err := ep.SocketOptions().GetLastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1046,10 +1042,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return &v, nil
case linux.SO_BINDTODEVICE:
- var v tcpip.BindToDeviceOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetBindToDevice()
if v == 0 {
var b primitive.ByteSlice
return &b, nil
@@ -1092,11 +1085,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.LingerOption
var linger linux.Linger
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
+ v := ep.SocketOptions().GetLinger()
if v.Enabled {
linger.OnOff = 1
@@ -1127,13 +1117,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.OutOfBandInlineOption
- if err := ep.GetSockOpt(&v); err != nil {
- return nil, syserr.TranslateNetstackError(err)
- }
-
- vP := primitive.Int32(v)
- return &vP, nil
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline()))
+ return &v, nil
case linux.SO_NO_CHECK:
if outLen < sizeOfInt32 {
@@ -1417,6 +1402,21 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass()))
return &v, nil
+ case linux.IPV6_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
+
+ case linux.IPV6_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
case linux.IP6T_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet6{})) {
@@ -1583,6 +1583,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS()))
return &v, nil
+ case linux.IP_RECVERR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetRecvError()))
+ return &v, nil
+
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1599,6 +1607,14 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded()))
return &v, nil
+ case linux.IP_RECVORIGDSTADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress()))
+ return &v, nil
+
case linux.SO_ORIGINAL_DST:
if outLen < int(binary.Size(linux.SockAddrInet{})) {
return nil, syserr.ErrInvalidArgument
@@ -1785,8 +1801,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
name := string(optVal[:n])
if name == "" {
- v := tcpip.BindToDeviceOption(0)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(0))
}
s := t.NetworkContext()
if s == nil {
@@ -1794,8 +1809,7 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
for nicID, nic := range s.Interfaces() {
if nic.Name == name {
- v := tcpip.BindToDeviceOption(nicID)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
+ return syserr.TranslateNetstackError(ep.SocketOptions().SetBindToDevice(nicID))
}
}
return syserr.ErrUnknownDevice
@@ -1864,8 +1878,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- opt := tcpip.OutOfBandInlineOption(v)
- return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
+ ep.SocketOptions().SetOutOfBandInline(v != 0)
+ return nil
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1888,10 +1902,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(
- ep.SetSockOpt(&tcpip.LingerOption{
- Enabled: v.OnOff != 0,
- Timeout: time.Second * time.Duration(v.Linger)}))
+ ep.SocketOptions().SetLinger(tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger),
+ })
+ return nil
case linux.SO_DETACH_FILTER:
// optval is ignored.
@@ -2094,6 +2109,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_RECVORIGDSTADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPV6_TCLASS:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -2115,6 +2139,16 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name
ep.SocketOptions().SetReceiveTClass(v != 0)
return nil
+ case linux.IPV6_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
case linux.IP6T_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIP6TReplace {
@@ -2303,6 +2337,17 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
ep.SocketOptions().SetReceiveTOS(v != 0)
return nil
+ case linux.IP_RECVERR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ ep.SocketOptions().SetRecvError(v != 0)
+ return nil
+
case linux.IP_PKTINFO:
if len(optVal) == 0 {
return nil
@@ -2325,6 +2370,18 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
ep.SocketOptions().SetHeaderIncluded(v != 0)
return nil
+ case linux.IP_RECVORIGDSTADDR:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0)
+ return nil
+
case linux.IPT_SO_SET_REPLACE:
if len(optVal) < linux.SizeOfIPTReplace {
return syserr.ErrInvalidArgument
@@ -2360,10 +2417,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
- linux.IP_RECVORIGDSTADDR,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
linux.IP_TRANSPARENT,
@@ -2437,11 +2492,9 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_MULTICAST_IF,
linux.IPV6_MULTICAST_LOOP,
linux.IPV6_RECVDSTOPTS,
- linux.IPV6_RECVERR,
linux.IPV6_RECVFRAGSIZE,
linux.IPV6_RECVHOPLIMIT,
linux.IPV6_RECVHOPOPTS,
- linux.IPV6_RECVORIGDSTADDR,
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
@@ -2472,7 +2525,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) {
linux.IP_PKTINFO,
linux.IP_PKTOPTIONS,
linux.IP_MTU_DISCOVER,
- linux.IP_RECVERR,
linux.IP_RECVTTL,
linux.IP_RECVTOS,
linux.IP_MTU,
@@ -2701,7 +2753,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
// We need to peek beyond the first message.
dst = dst.DropFirst(n)
num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
- n, _, err := s.Endpoint.Peek(dsts)
+ n, err := s.Endpoint.Peek(dsts)
// TODO(b/78348848): Handle peek timestamp.
if err != nil {
return int64(n), syserr.TranslateNetstackError(err).ToError()
@@ -2745,15 +2797,19 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq
func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
- IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
- HasTClass: s.readCM.HasTClass,
- TClass: s.readCM.TClass,
- HasIPPacketInfo: s.readCM.HasIPPacketInfo,
- PacketInfo: s.readCM.PacketInfo,
+ IP: socket.IPControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasInq: s.readCM.HasInq,
+ Inq: s.readCM.Inq,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ OriginalDstAddress: s.readCM.OriginalDstAddress,
+ SockErr: s.readCM.SockErr,
},
}
}
@@ -2770,9 +2826,66 @@ func (s *socketOpsCommon) updateTimestamp() {
}
}
+// dequeueErr is analogous to net/core/skbuff.c:sock_dequeue_err_skb().
+func (s *socketOpsCommon) dequeueErr() *tcpip.SockError {
+ so := s.Endpoint.SocketOptions()
+ err := so.DequeueErr()
+ if err == nil {
+ return nil
+ }
+
+ // Update socket error to reflect ICMP errors in queue.
+ if nextErr := so.PeekErr(); nextErr != nil && nextErr.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nextErr.Err)
+ } else if err.ErrOrigin.IsICMPErr() {
+ so.SetLastError(nil)
+ }
+ return err
+}
+
+// addrFamilyFromNetProto returns the address family identifier for the given
+// network protocol.
+func addrFamilyFromNetProto(net tcpip.NetworkProtocolNumber) int {
+ switch net {
+ case header.IPv4ProtocolNumber:
+ return linux.AF_INET
+ case header.IPv6ProtocolNumber:
+ return linux.AF_INET6
+ default:
+ panic(fmt.Sprintf("invalid net proto for addr family inference: %d", net))
+ }
+}
+
+// recvErr handles MSG_ERRQUEUE for recvmsg(2).
+// This is analogous to net/ipv4/ip_sockglue.c:ip_recv_error().
+func (s *socketOpsCommon) recvErr(t *kernel.Task, dst usermem.IOSequence) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ sockErr := s.dequeueErr()
+ if sockErr == nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ // The payload of the original packet that caused the error is passed as
+ // normal data via msg_iovec. -- recvmsg(2)
+ msgFlags := linux.MSG_ERRQUEUE
+ if int(dst.NumBytes()) < len(sockErr.Payload) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ n, err := dst.CopyOut(t, sockErr.Payload)
+
+ // The original destination address of the datagram that caused the error is
+ // supplied via msg_name. -- recvmsg(2)
+ dstAddr, dstAddrLen := socket.ConvertAddress(addrFamilyFromNetProto(sockErr.NetProto), sockErr.Dst)
+ cmgs := socket.ControlMessages{IP: socket.NewIPControlMessages(s.family, tcpip.ControlMessages{SockErr: sockErr})}
+ return n, msgFlags, dstAddr, dstAddrLen, cmgs, syserr.FromError(err)
+}
+
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// tcpip.Endpoint.
func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ if flags&linux.MSG_ERRQUEUE != 0 {
+ return s.recvErr(t, dst)
+ }
+
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0