diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/control/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/control/control.go | 78 | ||||
-rw-r--r-- | pkg/sentry/socket/hostinet/socket.go | 160 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 12 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 705 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_vfs2.go | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/provider.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/provider_vfs2.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 30 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 249 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 27 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 78 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 15 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix_vfs2.go | 2 |
16 files changed, 762 insertions, 607 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index a3f775d15..cc1f6bfcc 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index ca16d0381..fb7c5dc61 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -23,7 +23,6 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", "//pkg/syserror", - "//pkg/tcpip", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 70ccf77a7..b88cdca48 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -26,7 +26,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/usermem" ) @@ -344,21 +343,34 @@ func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { } // PackIPPacketInfo packs an IP_PKTINFO socket control message. -func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { - var p linux.ControlMessageIPPacketInfo - p.NIC = int32(packetInfo.NIC) - copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) - copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) - +func PackIPPacketInfo(t *kernel.Task, packetInfo *linux.ControlMessageIPPacketInfo, buf []byte) []byte { return putCmsgStruct( buf, linux.SOL_IP, linux.IP_PKTINFO, t.Arch().Width(), - p, + packetInfo, ) } +// PackOriginalDstAddress packs an IP_RECVORIGINALDSTADDR socket control message. +func PackOriginalDstAddress(t *kernel.Task, originalDstAddress linux.SockAddr, buf []byte) []byte { + var level uint32 + var optType uint32 + switch originalDstAddress.(type) { + case *linux.SockAddrInet: + level = linux.SOL_IP + optType = linux.IP_RECVORIGDSTADDR + case *linux.SockAddrInet6: + level = linux.SOL_IPV6 + optType = linux.IPV6_RECVORIGDSTADDR + default: + panic("invalid address type, must be an IP address for IP_RECVORIGINALDSTADDR cmsg") + } + return putCmsgStruct( + buf, level, optType, t.Arch().Width(), originalDstAddress) +} + // PackControlMessages packs control messages into the given buffer. // // We skip control messages specific to Unix domain sockets. @@ -384,7 +396,11 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt } if cmsgs.IP.HasIPPacketInfo { - buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + buf = PackIPPacketInfo(t, &cmsgs.IP.PacketInfo, buf) + } + + if cmsgs.IP.OriginalDstAddress != nil { + buf = PackOriginalDstAddress(t, cmsgs.IP.OriginalDstAddress, buf) } return buf @@ -416,17 +432,15 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { space += cmsgSpace(t, linux.SizeOfControlMessageTClass) } - return space -} + if cmsgs.IP.HasIPPacketInfo { + space += cmsgSpace(t, linux.SizeOfControlMessageIPPacketInfo) + } -// NewIPPacketInfo returns the IPPacketInfo struct. -func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { - var p tcpip.IPPacketInfo - p.NIC = tcpip.NICID(packetInfo.NIC) - copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) - copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + if cmsgs.IP.OriginalDstAddress != nil { + space += cmsgSpace(t, cmsgs.IP.OriginalDstAddress.SizeBytes()) + } - return p + return space } // Parse parses a raw socket control message into portable objects. @@ -489,6 +503,14 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con cmsgs.Unix.Credentials = scmCreds i += binary.AlignUp(length, width) + case linux.SO_TIMESTAMP: + if length < linux.SizeOfTimeval { + return socket.ControlMessages{}, syserror.EINVAL + } + cmsgs.IP.HasTimestamp = true + binary.Unmarshal(buf[i:i+linux.SizeOfTimeval], usermem.ByteOrder, &cmsgs.IP.Timestamp) + i += binary.AlignUp(length, width) + default: // Unknown message type. return socket.ControlMessages{}, syserror.EINVAL @@ -512,7 +534,16 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + cmsgs.IP.PacketInfo = packetInfo + i += binary.AlignUp(length, width) + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr i += binary.AlignUp(length, width) default: @@ -528,6 +559,15 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) i += binary.AlignUp(length, width) + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + if length < addr.SizeBytes() { + return socket.ControlMessages{}, syserror.EINVAL + } + binary.Unmarshal(buf[i:i+addr.SizeBytes()], usermem.ByteOrder, &addr) + cmsgs.IP.OriginalDstAddress = &addr + i += binary.AlignUp(length, width) + default: return socket.ControlMessages{}, syserror.EINVAL } diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 7d3c4a01c..1f220c343 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -331,17 +331,17 @@ func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO, linux.IP_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 case linux.SO_LINGER: optlen = syscall.SizeofLinger @@ -377,24 +377,24 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] switch level { case linux.SOL_IP: switch name { - case linux.IP_TOS, linux.IP_RECVTOS: + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_RECVORIGDSTADDR: optlen = sizeofInt32 case linux.IP_PKTINFO: optlen = linux.SizeOfControlMessageIPPacketInfo } case linux.SOL_IPV6: switch name { - case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY, linux.IPV6_RECVORIGDSTADDR: optlen = sizeofInt32 } case linux.SOL_SOCKET: switch name { - case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR, linux.SO_TIMESTAMP: optlen = sizeofInt32 } case linux.SOL_TCP: switch name { - case linux.TCP_NODELAY: + case linux.TCP_NODELAY, linux.TCP_INQ: optlen = sizeofInt32 } } @@ -416,6 +416,37 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] return nil } +func (s *socketOpsCommon) recvMsgFromHost(iovs []syscall.Iovec, flags int, senderRequested bool, controlLen uint64) (uint64, int, []byte, []byte, error) { + // We always do a non-blocking recv*(). + sysflags := flags | syscall.MSG_DONTWAIT + + msg := syscall.Msghdr{} + if len(iovs) > 0 { + msg.Iov = &iovs[0] + msg.Iovlen = uint64(len(iovs)) + } + var senderAddrBuf []byte + if senderRequested { + senderAddrBuf = make([]byte, sizeofSockaddr) + msg.Name = &senderAddrBuf[0] + msg.Namelen = uint32(sizeofSockaddr) + } + var controlBuf []byte + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } + n, err := recvmsg(s.fd, &msg, sysflags) + if err != nil { + return 0 /* n */, 0 /* mFlags */, nil /* senderAddrBuf */, nil /* controlBuf */, err + } + return n, int(msg.Flags), senderAddrBuf[:msg.Namelen], controlBuf[:msg.Controllen], err +} + // RecvMsg implements socket.Socket.RecvMsg. func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { // Only allow known and safe flags. @@ -427,56 +458,36 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument } - var senderAddr linux.SockAddr var senderAddrBuf []byte - if senderRequested { - senderAddrBuf = make([]byte, sizeofSockaddr) - } - var controlBuf []byte var msgFlags int - - recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { - // Refuse to do anything if any part of dst.Addrs was unusable. - if uint64(dst.NumBytes()) != dsts.NumBytes() { - return 0, nil - } - if dsts.IsEmpty() { - return 0, nil - } - - // We always do a non-blocking recv*(). - sysflags := flags | syscall.MSG_DONTWAIT - - iovs := safemem.IovecsFromBlockSeq(dsts) - msg := syscall.Msghdr{ - Iov: &iovs[0], - Iovlen: uint64(len(iovs)), - } - if len(senderAddrBuf) != 0 { - msg.Name = &senderAddrBuf[0] - msg.Namelen = uint32(len(senderAddrBuf)) - } - if controlLen > 0 { - if controlLen > maxControlLen { - controlLen = maxControlLen + copyToDst := func() (int64, error) { + var n uint64 + var err error + if dst.NumBytes() == 0 { + // We want to make the recvmsg(2) call to the host even if dst is empty + // to fetch control messages, sender address or errors if any occur. + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(nil, flags, senderRequested, controlLen) + return int64(n), err + } + + recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { + // Refuse to do anything if any part of dst.Addrs was unusable. + if uint64(dst.NumBytes()) != dsts.NumBytes() { + return 0, nil + } + if dsts.IsEmpty() { + return 0, nil } - controlBuf = make([]byte, controlLen) - msg.Control = &controlBuf[0] - msg.Controllen = controlLen - } - n, err := recvmsg(s.fd, &msg, sysflags) - if err != nil { - return 0, err - } - senderAddrBuf = senderAddrBuf[:msg.Namelen] - msgFlags = int(msg.Flags) - controlLen = uint64(msg.Controllen) - return n, nil - }) + + n, msgFlags, senderAddrBuf, controlBuf, err = s.recvMsgFromHost(safemem.IovecsFromBlockSeq(dsts), flags, senderRequested, controlLen) + return n, err + }) + return dst.CopyOutFrom(t, recvmsgToBlocks) + } var ch chan struct{} - n, err := dst.CopyOutFrom(t, recvmsgToBlocks) + n, err := copyToDst() if flags&syscall.MSG_DONTWAIT == 0 { for err == syserror.ErrWouldBlock { // We only expect blocking to come from the actual syscall, in which @@ -494,48 +505,75 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags s.EventRegister(&e, waiter.EventIn) defer s.EventUnregister(&e) } - n, err = dst.CopyOutFrom(t, recvmsgToBlocks) + n, err = copyToDst() } } if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + var senderAddr linux.SockAddr if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf) if err != nil { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) } + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), parseUnixControlMessages(unixControlMessages), nil +} +func parseUnixControlMessages(unixControlMessages []unix.SocketControlMessage) socket.ControlMessages { controlMessages := socket.ControlMessages{} for _, unixCmsg := range unixControlMessages { switch unixCmsg.Header.Level { - case syscall.SOL_IP: + case linux.SOL_SOCKET: + switch unixCmsg.Header.Type { + case linux.SO_TIMESTAMP: + controlMessages.IP.HasTimestamp = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &controlMessages.IP.Timestamp) + } + + case linux.SOL_IP: switch unixCmsg.Header.Type { - case syscall.IP_TOS: + case linux.IP_TOS: controlMessages.IP.HasTOS = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) - case syscall.IP_PKTINFO: + case linux.IP_PKTINFO: controlMessages.IP.HasIPPacketInfo = true var packetInfo linux.ControlMessageIPPacketInfo binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) - controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + controlMessages.IP.PacketInfo = packetInfo + + case linux.IP_RECVORIGDSTADDR: + var addr linux.SockAddrInet + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr } - case syscall.SOL_IPV6: + case linux.SOL_IPV6: switch unixCmsg.Header.Type { - case syscall.IPV6_TCLASS: + case linux.IPV6_TCLASS: controlMessages.IP.HasTClass = true binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + + case linux.IPV6_RECVORIGDSTADDR: + var addr linux.SockAddrInet6 + binary.Unmarshal(unixCmsg.Data[:addr.SizeBytes()], usermem.ByteOrder, &addr) + controlMessages.IP.OriginalDstAddress = &addr + } + + case linux.SOL_TCP: + switch unixCmsg.Header.Type { + case linux.TCP_INQ: + controlMessages.IP.HasInq = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageInq], usermem.ByteOrder, &controlMessages.IP.Inq) } } } - - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil + return controlMessages } // SendMsg implements socket.Socket.SendMsg. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 3baad098b..057f4d294 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -120,9 +120,6 @@ type socketOpsCommon struct { // fixed buffer but only consume this many bytes. sendBufferSize uint32 - // passcred indicates if this socket wants SCM credentials. - passcred bool - // filter indicates that this socket has a BPF filter "installed". // // TODO(gvisor.dev/issue/1119): We don't actually support filtering, @@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { // Passcred implements transport.Credentialer.Passcred. func (s *socketOpsCommon) Passcred() bool { - s.mu.Lock() - passcred := s.passcred - s.mu.Unlock() - return passcred + return s.ep.SocketOptions().GetPassCred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. @@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] } passcred := usermem.ByteOrder.Uint32(opt) - s.mu.Lock() - s.passcred = passcred != 0 - s.mu.Unlock() + s.ep.SocketOptions().SetPassCred(passcred != 0) return nil case linux.SO_ATTACH_FILTER: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 7d0ae15ca..23d5cab9c 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -84,69 +84,95 @@ var Metrics = tcpip.Stats{ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), ICMP: tcpip.ICMPStats{ - V4PacketsSent: tcpip.ICMPv4SentPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + V4: tcpip.ICMPv4Stats{ + PacketsSent: tcpip.ICMPv4SentPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), }, - V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + V6: tcpip.ICMPv6Stats{ + PacketsSent: tcpip.ICMPv6SentPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + }, + PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - V6PacketsSent: tcpip.ICMPv6SentPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + IGMP: tcpip.IGMPStats{ + PacketsSent: tcpip.IGMPSentPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_sent/membership_query", "Total number of IGMP Membership Query messages sent by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v1_membership_report", "Total number of IGMPv1 Membership Report messages sent by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_sent/v2_membership_report", "Total number of IGMPv2 Membership Report messages sent by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_sent/leave_group", "Total number of IGMP Leave Group messages sent by netstack."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), + Dropped: mustCreateMetric("/netstack/igmp/packets_sent/dropped", "Total number of IGMP packets dropped by netstack due to link layer errors."), }, - V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + PacketsReceived: tcpip.IGMPReceivedPacketStats{ + IGMPPacketStats: tcpip.IGMPPacketStats{ + MembershipQuery: mustCreateMetric("/netstack/igmp/packets_received/membership_query", "Total number of IGMP Membership Query messages received by netstack."), + V1MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v1_membership_report", "Total number of IGMPv1 Membership Report messages received by netstack."), + V2MembershipReport: mustCreateMetric("/netstack/igmp/packets_received/v2_membership_report", "Total number of IGMPv2 Membership Report messages received by netstack."), + LeaveGroup: mustCreateMetric("/netstack/igmp/packets_received/leave_group", "Total number of IGMP Leave Group messages received by netstack."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), + Invalid: mustCreateMetric("/netstack/igmp/packets_received/invalid", "Total number of IGMP packets received by netstack that could not be parsed."), + ChecksumErrors: mustCreateMetric("/netstack/igmp/packets_received/checksum_errors", "Total number of received IGMP packets with bad checksums."), + Unrecognized: mustCreateMetric("/netstack/igmp/packets_received/unrecognized", "Total number of unrecognized IGMP packets received by netstack."), }, }, IP: tcpip.IPStats{ @@ -209,18 +235,6 @@ const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) -// ntohs converts a 16-bit number from network byte order to host byte order. It -// assumes that the host is little endian. -func ntohs(v uint16) uint16 { - return v<<8 | v>>8 -} - -// htons converts a 16-bit number from host byte order to network byte order. It -// assumes that the host is little endian. -func htons(v uint16) uint16 { - return ntohs(v) -} - // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. type commonEndpoint interface { @@ -240,10 +254,6 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and - // transport.Endpoint.SetSockOptBool. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -252,18 +262,20 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and - // transport.Endpoint.GetSockOpt. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) - // LastError implements tcpip.Endpoint.LastError. + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + + // LastError implements tcpip.Endpoint.LastError and + // transport.Endpoint.LastError. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions implements tcpip.Endpoint.SocketOptions and + // transport.Endpoint.SocketOptions. SocketOptions() *tcpip.SocketOptions } @@ -308,7 +320,7 @@ type socketOpsCommon struct { readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages + readCM socket.IPControlMessages sender tcpip.FullAddress linkPacketInfo tcpip.LinkPacketInfo @@ -332,9 +344,7 @@ type socketOpsCommon struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } dirent := socket.NewDirent(t, netstackDevice) @@ -363,88 +373,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, -// AF_INET6, and AF_PACKET addresses. -// -// AddressAndFamily returns an address and its family. -func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { - // Make sure we have at least 2 bytes for the address family. - if len(addr) < 2 { - return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument - } - - // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { - case linux.AF_UNIX: - path := addr[2:] - if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - // Drop the terminating NUL (if one exists) and everything after - // it for filesystem (non-abstract) addresses. - if len(path) > 0 && path[0] != 0 { - if n := bytes.IndexByte(path[1:], 0); n >= 0 { - path = path[:n+1] - } - } - return tcpip.FullAddress{ - Addr: tcpip.Address(path), - }, family, nil - - case linux.AF_INET: - var a linux.SockAddrInet - if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - return out, family, nil - - case linux.AF_INET6: - var a linux.SockAddrInet6 - if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - if isLinkLocal(out.Addr) { - out.NIC = tcpip.NICID(a.Scope_id) - } - return out, family, nil - - case linux.AF_PACKET: - var a linux.SockAddrLink - if len(addr) < sockAddrLinkSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) - if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/173): Return protocol too. - return tcpip.FullAddress{ - NIC: tcpip.NICID(a.InterfaceIndex), - Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), - }, family, nil - - case linux.AF_UNSPEC: - return tcpip.FullAddress{}, family, nil - - default: - return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported - } -} - func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -480,7 +408,7 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error { } s.readView = v - s.readCM = cms + s.readCM = socket.NewIPControlMessages(s.family, cms) atomic.StoreUint32(&s.readViewHasData, 1) return nil @@ -500,11 +428,7 @@ func (s *socketOpsCommon) Release(ctx context.Context) { return } - var v tcpip.LingerOption - if err := s.Endpoint.GetSockOpt(&v); err != nil { - return - } - + v := s.Endpoint.SocketOptions().GetLinger() // The case for zero timeout is handled in tcp endpoint close function. // Close is blocked until either: // 1. The endpoint state is not in any of the states: FIN-WAIT1, @@ -721,11 +645,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -749,7 +669,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -830,7 +750,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } } else { var err *syserr.Error - addr, family, err = AddressAndFamily(sockaddr) + addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -921,7 +841,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1005,7 +925,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptSocket(t, s, ep, family, skType, name, outLen) case linux.SOL_TCP: - return getSockOptTCP(t, ep, name, outLen) + return getSockOptTCP(t, s, ep, name, outLen) case linux.SOL_IPV6: return getSockOptIPv6(t, s, ep, name, outPtr, outLen) @@ -1041,7 +961,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.LastError() + err := ep.SocketOptions().GetLastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1068,13 +988,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.PasscredOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred())) + return &v, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1115,25 +1030,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress())) + return &v, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReusePortOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort())) + return &v, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1174,24 +1080,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive())) + return &v, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - var v tcpip.LingerOption var linger linux.Linger - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + v := ep.SocketOptions().GetLinger() if v.Enabled { linger.OnOff = 1 @@ -1222,34 +1120,26 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - var v tcpip.OutOfBandInlineOption - if err := ep.GetSockOpt(&v); err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(v) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetOutOfBandInline())) + return &v, nil case linux.SO_NO_CHECK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum())) + return &v, nil case linux.SO_ACCEPTCONN: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + // This option is only viable for TCP endpoints. + var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen } vP := primitive.Int32(boolToInt32(v)) return &vP, nil @@ -1261,46 +1151,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(!v)) - return &vP, nil + v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption())) + return &v, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.CorkOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption())) + return &v, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.QuickAckOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck())) + return &v, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1474,19 +1354,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1518,13 +1403,16 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil + + case linux.IPV6_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1536,7 +1424,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: @@ -1545,7 +1433,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1565,7 +1453,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1585,7 +1473,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1607,6 +1495,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1649,7 +1542,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) return &a.(*linux.SockAddrInet).Addr, nil @@ -1658,13 +1551,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1688,26 +1576,32 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil + + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil - case linux.IP_PKTINFO: + case linux.IP_HDRINCL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil + + case linux.IP_RECVORIGDSTADDR: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveOriginalDstAddress())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -1719,7 +1613,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil case linux.IPT_SO_GET_INFO: @@ -1826,7 +1720,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptSocket(t, s, ep, name, optVal) case linux.SOL_TCP: - return setSockOptTCP(t, ep, name, optVal) + return setSockOptTCP(t, s, ep, name, optVal) case linux.SOL_IPV6: return setSockOptIPv6(t, s, ep, name, optVal) @@ -1876,7 +1770,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) + ep.SocketOptions().SetReuseAddress(v != 0) + return nil case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1884,7 +1779,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) + ep.SocketOptions().SetReusePort(v != 0) + return nil case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1923,7 +1819,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) + ep.SocketOptions().SetPassCred(v != 0) + return nil case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1931,7 +1828,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) + ep.SocketOptions().SetKeepAlive(v != 0) + return nil case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1970,8 +1868,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - opt := tcpip.OutOfBandInlineOption(v) - return syserr.TranslateNetstackError(ep.SetSockOpt(&opt)) + ep.SocketOptions().SetOutOfBandInline(v != 0) + return nil case linux.SO_NO_CHECK: if len(optVal) < sizeOfInt32 { @@ -1979,7 +1877,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + ep.SocketOptions().SetNoChecksum(v != 0) + return nil case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { @@ -1993,10 +1892,11 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam socket.SetSockOptEmitUnimplementedEvent(t, name) } - return syserr.TranslateNetstackError( - ep.SetSockOpt(&tcpip.LingerOption{ - Enabled: v.OnOff != 0, - Timeout: time.Second * time.Duration(v.Linger)})) + ep.SocketOptions().SetLinger(tcpip.LingerOption{ + Enabled: v.OnOff != 0, + Timeout: time.Second * time.Duration(v.Linger), + }) + return nil case linux.SO_DETACH_FILTER: // optval is ignored. @@ -2011,7 +1911,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. -func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if len(optVal) < sizeOfInt32 { @@ -2019,7 +1924,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) + ep.SocketOptions().SetDelayOption(v == 0) + return nil case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -2027,7 +1933,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) + ep.SocketOptions().SetCorkOption(v != 0) + return nil case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -2035,7 +1942,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) + ep.SocketOptions().SetQuickAck(v != 0) + return nil case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -2147,14 +2055,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, @@ -2174,6 +2099,15 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_RECVORIGDSTADDR: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil + case linux.IPV6_TCLASS: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2193,7 +2127,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2201,7 +2136,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2276,6 +2211,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2328,7 +2268,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), - InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), + InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]), })) case linux.IP_MULTICAST_LOOP: @@ -2337,7 +2277,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2373,7 +2314,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2383,7 +2325,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2393,7 +2336,20 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil + + case linux.IP_RECVORIGDSTADDR: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + ep.SocketOptions().SetReceiveOriginalDstAddress(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2433,7 +2389,6 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, - linux.IP_RECVORIGDSTADDR, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -2511,7 +2466,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVFRAGSIZE, linux.IPV6_RECVHOPLIMIT, linux.IPV6_RECVHOPOPTS, - linux.IPV6_RECVORIGDSTADDR, linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, @@ -2535,7 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -2582,72 +2535,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { } } -// isLinkLocal determines if the given IPv6 address is link-local. This is the -// case when it has the fe80::/10 prefix. This check is used to determine when -// the NICID is relevant for a given IPv6 address. -func isLinkLocal(addr tcpip.Address) bool { - return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 -} - -// ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { - switch family { - case linux.AF_UNIX: - var out linux.SockAddrUnix - out.Family = linux.AF_UNIX - l := len([]byte(addr.Addr)) - for i := 0; i < l; i++ { - out.Path[i] = int8(addr.Addr[i]) - } - - // Linux returns the used length of the address struct (including the - // null terminator) for filesystem paths. The Family field is 2 bytes. - // It is sometimes allowed to exclude the null terminator if the - // address length is the max. Abstract and empty paths always return - // the full exact length. - if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return &out, uint32(2 + l) - } - return &out, uint32(3 + l) - - case linux.AF_INET: - var out linux.SockAddrInet - copy(out.Addr[:], addr.Addr) - out.Family = linux.AF_INET - out.Port = htons(addr.Port) - return &out, uint32(sockAddrInetSize) - - case linux.AF_INET6: - var out linux.SockAddrInet6 - if len(addr.Addr) == header.IPv4AddressSize { - // Copy address in v4-mapped format. - copy(out.Addr[12:], addr.Addr) - out.Addr[10] = 0xff - out.Addr[11] = 0xff - } else { - copy(out.Addr[:], addr.Addr) - } - out.Family = linux.AF_INET6 - out.Port = htons(addr.Port) - if isLinkLocal(addr.Addr) { - out.Scope_id = uint32(addr.NIC) - } - return &out, uint32(sockAddrInet6Size) - - case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. - var out linux.SockAddrLink - out.Family = linux.AF_PACKET - out.InterfaceIndex = int32(addr.NIC) - out.HardwareAddrLen = header.EthernetAddressSize - copy(out.HardwareAddr[:], addr.Addr) - return &out, uint32(sockAddrLinkSize) - - default: - return nil, 0 - } -} - // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -2656,7 +2543,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2668,7 +2555,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2686,7 +2573,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ // Always do at least one fetchReadView, even if the number of bytes to // read is 0. err = s.fetchReadView() - if err != nil { + if err != nil || len(s.readView) == 0 { break } if dst.NumBytes() == 0 { @@ -2709,15 +2596,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ } copied += n s.readView.TrimFront(n) - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) - } dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) break } + // If we are done reading requested data then stop. + if dst.NumBytes() == 0 { + break + } + } + + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) } // If we managed to copy something, we must deliver it. @@ -2812,10 +2704,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { - addr, addrLen = ConvertAddress(s.family, s.sender) + addr, addrLen = socket.ConvertAddress(s.family, s.sender) switch v := addr.(type) { case *linux.SockAddrLink: - v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } @@ -2833,7 +2725,7 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq // We need to peek beyond the first message. dst = dst.DropFirst(n) num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) { - n, _, err := s.Endpoint.Peek(dsts) + n, err := s.Endpoint.Peek(dsts) // TODO(b/78348848): Handle peek timestamp. if err != nil { return int64(n), syserr.TranslateNetstackError(err).ToError() @@ -2877,15 +2769,16 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq func (s *socketOpsCommon) controlMessages() socket.ControlMessages { return socket.ControlMessages{ - IP: tcpip.ControlMessages{ - HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, - Timestamp: s.readCM.Timestamp, - HasTOS: s.readCM.HasTOS, - TOS: s.readCM.TOS, - HasTClass: s.readCM.HasTClass, - TClass: s.readCM.TClass, - HasIPPacketInfo: s.readCM.HasIPPacketInfo, - PacketInfo: s.readCM.PacketInfo, + IP: socket.IPControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + OriginalDstAddress: s.readCM.OriginalDstAddress, }, } } @@ -2980,7 +2873,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, family, err := AddressAndFamily(to) + addrBuf, family, err := socket.AddressAndFamily(to) if err != nil { return 0, err } @@ -3399,6 +3292,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3408,7 +3313,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3437,7 +3342,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3447,7 +3352,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index b0d9e4d9e..b756bfca0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } mnt := t.Kernel().SocketMount() @@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addrLen uint32 if peerAddr != nil { // Get address of the peer and write it to peer slice. - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index ead3b2b79..c847ff1c7 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 2a01143f6..0af805246 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fa9ac9059..cc0fadeb5 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: - in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats - out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ 0, // Icmp/InMsgs. - Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors. 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. @@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. 0, // Icmp/OutMsgs. - Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. - out.DstUnreachable.Value(), // OutDestUnreachs. - out.TimeExceeded.Value(), // OutTimeExcds. - out.ParamProblem.Value(), // OutParmProbs. - out.SrcQuench.Value(), // OutSrcQuenchs. - out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. - out.EchoReply.Value(), // OutEchoReps. - out.Timestamp.Value(), // OutTimestamps. - out.TimestampReply.Value(), // OutTimestampReps. - out.InfoRequest.Value(), // OutAddrMasks. - out.InfoReply.Value(), // OutAddrMaskReps. + Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. } case *inet.StatSNMPTCP: tcp := Metrics.TCP diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fd31479e5..bcc426e33 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -18,6 +18,7 @@ package socket import ( + "bytes" "fmt" "sync/atomic" "syscall" @@ -35,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/usermem" ) @@ -42,7 +44,79 @@ import ( // control messages. type ControlMessages struct { Unix transport.ControlMessages - IP tcpip.ControlMessages + IP IPControlMessages +} + +// packetInfoToLinux converts IPPacketInfo from tcpip format to Linux format. +func packetInfoToLinux(packetInfo tcpip.IPPacketInfo) linux.ControlMessageIPPacketInfo { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + return p +} + +// NewIPControlMessages converts the tcpip ControlMessgaes (which does not +// have Linux specific format) to Linux format. +func NewIPControlMessages(family int, cmgs tcpip.ControlMessages) IPControlMessages { + var orgDstAddr linux.SockAddr + if cmgs.HasOriginalDstAddress { + orgDstAddr, _ = ConvertAddress(family, cmgs.OriginalDstAddress) + } + return IPControlMessages{ + HasTimestamp: cmgs.HasTimestamp, + Timestamp: cmgs.Timestamp, + HasInq: cmgs.HasInq, + Inq: cmgs.Inq, + HasTOS: cmgs.HasTOS, + TOS: cmgs.TOS, + HasTClass: cmgs.HasTClass, + TClass: cmgs.TClass, + HasIPPacketInfo: cmgs.HasIPPacketInfo, + PacketInfo: packetInfoToLinux(cmgs.PacketInfo), + OriginalDstAddress: orgDstAddr, + } +} + +// IPControlMessages contains socket control messages for IP sockets. +// This can contain Linux specific structures unlike tcpip.ControlMessages. +// +// +stateify savable +type IPControlMessages struct { + // HasTimestamp indicates whether Timestamp is valid/set. + HasTimestamp bool + + // Timestamp is the time (in ns) that the last packet used to create + // the read data was received. + Timestamp int64 + + // HasInq indicates whether Inq is valid/set. + HasInq bool + + // Inq is the number of bytes ready to be received. + Inq int32 + + // HasTOS indicates whether Tos is valid/set. + HasTOS bool + + // TOS is the IPv4 type of service of the associated packet. + TOS uint8 + + // HasTClass indicates whether TClass is valid/set. + HasTClass bool + + // TClass is the IPv6 traffic class of the associated packet. + TClass uint32 + + // HasIPPacketInfo indicates whether PacketInfo is set. + HasIPPacketInfo bool + + // PacketInfo holds interface and address data on an incoming packet. + PacketInfo linux.ControlMessageIPPacketInfo + + // OriginalDestinationAddress holds the original destination address + // and port of the incoming packet. + OriginalDstAddress linux.SockAddr } // Release releases Unix domain socket credentials and rights. @@ -460,3 +534,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { panic(fmt.Sprintf("Unsupported socket family %v", family)) } } + +var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes() +var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes() +var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes() + +// Ntohs converts a 16-bit number from network byte order to host byte order. It +// assumes that the host is little endian. +func Ntohs(v uint16) uint16 { + return v<<8 | v>>8 +} + +// Htons converts a 16-bit number from host byte order to network byte order. It +// assumes that the host is little endian. +func Htons(v uint16) uint16 { + return Ntohs(v) +} + +// isLinkLocal determines if the given IPv6 address is link-local. This is the +// case when it has the fe80::/10 prefix. This check is used to determine when +// the NICID is relevant for a given IPv6 address. +func isLinkLocal(addr tcpip.Address) bool { + return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 +} + +// ConvertAddress converts the given address to a native format. +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { + switch family { + case linux.AF_UNIX: + var out linux.SockAddrUnix + out.Family = linux.AF_UNIX + l := len([]byte(addr.Addr)) + for i := 0; i < l; i++ { + out.Path[i] = int8(addr.Addr[i]) + } + + // Linux returns the used length of the address struct (including the + // null terminator) for filesystem paths. The Family field is 2 bytes. + // It is sometimes allowed to exclude the null terminator if the + // address length is the max. Abstract and empty paths always return + // the full exact length. + if l == 0 || out.Path[0] == 0 || l == len(out.Path) { + return &out, uint32(2 + l) + } + return &out, uint32(3 + l) + + case linux.AF_INET: + var out linux.SockAddrInet + copy(out.Addr[:], addr.Addr) + out.Family = linux.AF_INET + out.Port = Htons(addr.Port) + return &out, uint32(sockAddrInetSize) + + case linux.AF_INET6: + var out linux.SockAddrInet6 + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. + copy(out.Addr[12:], addr.Addr) + out.Addr[10] = 0xff + out.Addr[11] = 0xff + } else { + copy(out.Addr[:], addr.Addr) + } + out.Family = linux.AF_INET6 + out.Port = Htons(addr.Port) + if isLinkLocal(addr.Addr) { + out.Scope_id = uint32(addr.NIC) + } + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(gvisor.dev/issue/173): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + + default: + return nil, 0 + } +} + +// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the +// netstack representation taking any addresses into account. +func BytesToIPAddress(addr []byte) tcpip.Address { + if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { + return "" + } + return tcpip.Address(addr) +} + +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. +// +// AddressAndFamily returns an address and its family. +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { + // Make sure we have at least 2 bytes for the address family. + if len(addr) < 2 { + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument + } + + // Get the rest of the fields based on the address family. + switch family := usermem.ByteOrder.Uint16(addr); family { + case linux.AF_UNIX: + path := addr[2:] + if len(path) > linux.UnixPathMax { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + // Drop the terminating NUL (if one exists) and everything after + // it for filesystem (non-abstract) addresses. + if len(path) > 0 && path[0] != 0 { + if n := bytes.IndexByte(path[1:], 0); n >= 0 { + path = path[:n+1] + } + } + return tcpip.FullAddress{ + Addr: tcpip.Address(path), + }, family, nil + + case linux.AF_INET: + var a linux.SockAddrInet + if len(addr) < sockAddrInetSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + return out, family, nil + + case linux.AF_INET6: + var a linux.SockAddrInet6 + if len(addr) < sockAddrInet6Size { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + if isLinkLocal(out.Addr) { + out.NIC = tcpip.NICID(a.Scope_id) + } + return out, family, nil + + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/173): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, family, nil + + default: + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported + } +} diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 6d9e502bd..9f7aca305 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -118,28 +118,24 @@ var ( // NewConnectioned creates a new unbound connectionedEndpoint. func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { - return &connectionedEndpoint{ + return newConnectioned(ctx, stype, uid) +} + +func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint { + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { - a := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } - b := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } + a := newConnectioned(ctx, stype, uid) + b := newConnectioned(ctx, stype, uid) q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q1.InitRefs() @@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { - return &connectionedEndpoint{ + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // ID implements ConnectingEndpoint.ID. @@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } + ne.ops.InitHandler(ne) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 1406971bc..0813ad87d 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 18a50e9f8..099a56281 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,8 +16,6 @@ package transport import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -180,10 +178,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool sets a socket option for simple cases when a value has - // the int type. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -191,10 +185,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool gets a socket option for simple cases when a return - // value has the int type. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -203,10 +193,11 @@ type Endpoint interface { // procfs. State() uint32 - // LastError implements tcpip.Endpoint.LastError. + // LastError clears and returns the last error reported by the endpoint. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions returns the structure which contains all the socket + // level options. SocketOptions() *tcpip.SocketOptions } @@ -739,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() { // +stateify savable type baseEndpoint struct { *waiter.Queue - - // passcred specifies whether SCM_CREDENTIALS socket control messages are - // enabled on this endpoint. Must be accessed atomically. - passcred int32 + tcpip.DefaultSocketOptionsHandler // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -758,9 +746,7 @@ type baseEndpoint struct { // or may be used if the endpoint is connected. path string - // linger is used for SO_LINGER socket option. - linger tcpip.LingerOption - + // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -786,7 +772,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { // Passcred implements Credentialer.Passcred. func (e *baseEndpoint) Passcred() bool { - return atomic.LoadInt32(&e.passcred) != 0 + return e.SocketOptions().GetPassCred() } // ConnectedPasscred implements Credentialer.ConnectedPasscred. @@ -796,14 +782,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool { return e.connected != nil && e.connected.Passcred() } -func (e *baseEndpoint) setPasscred(pc bool) { - if pc { - atomic.StoreInt32(&e.passcred, 1) - } else { - atomic.StoreInt32(&e.passcred, 0) - } -} - // Connected implements ConnectingEndpoint.Connected. func (e *baseEndpoint) Connected() bool { return e.receiver != nil && e.connected != nil @@ -859,23 +837,6 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { - switch v := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - e.linger = *v - e.Unlock() - } - return nil -} - -func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.PasscredOption: - e.setPasscred(v) - case tcpip.ReuseAddressOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } return nil } @@ -889,20 +850,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - case tcpip.PasscredOption: - return e.Passcred(), nil - - default: - log.Warningf("Unsupported socket option: %d", opt) - return false, tcpip.ErrUnknownProtocolOption - } -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -966,17 +913,8 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { - switch o := opt.(type) { - case *tcpip.LingerOption: - e.Lock() - *o = e.linger - e.Unlock() - return nil - - default: - log.Warningf("Unsupported socket option: %T", opt) - return tcpip.ErrUnknownProtocolOption - } + log.Warningf("Unsupported socket option: %T", opt) + return tcpip.ErrUnknownProtocolOption } // LastError implements Endpoint.LastError. diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 3e520d2ee..c59297c80 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -115,9 +115,6 @@ type socketOpsCommon struct { // bound, they cannot be modified. abstractName string abstractNamespace *kernel.AbstractSocketNamespace - - // ops is used to get socket level options. - ops tcpip.SocketOptions } func (s *socketOpsCommon) isPacket() bool { @@ -139,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, family, err := netstack.AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument @@ -172,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -184,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -258,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -650,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -685,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index eaf0b0d26..27f705bb2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ |