diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 12 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 594 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_vfs2.go | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/provider.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/provider_vfs2.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 30 | ||||
-rw-r--r-- | pkg/sentry/socket/socket.go | 175 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectioned.go | 27 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/connectionless.go | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 56 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 15 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix_vfs2.go | 2 |
13 files changed, 450 insertions, 473 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index a3f775d15..cc1f6bfcc 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -20,6 +20,7 @@ go_library( "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/tcpip", + "//pkg/tcpip/header", "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 3baad098b..057f4d294 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -120,9 +120,6 @@ type socketOpsCommon struct { // fixed buffer but only consume this many bytes. sendBufferSize uint32 - // passcred indicates if this socket wants SCM credentials. - passcred bool - // filter indicates that this socket has a BPF filter "installed". // // TODO(gvisor.dev/issue/1119): We don't actually support filtering, @@ -201,10 +198,7 @@ func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { // Passcred implements transport.Credentialer.Passcred. func (s *socketOpsCommon) Passcred() bool { - s.mu.Lock() - passcred := s.passcred - s.mu.Unlock() - return passcred + return s.ep.SocketOptions().GetPassCred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. @@ -419,9 +413,7 @@ func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt [] } passcred := usermem.ByteOrder.Uint32(opt) - s.mu.Lock() - s.passcred = passcred != 0 - s.mu.Unlock() + s.ep.SocketOptions().SetPassCred(passcred != 0) return nil case linux.SO_ATTACH_FILTER: diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 7d0ae15ca..5e9ab97ad 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -84,69 +84,73 @@ var Metrics = tcpip.Stats{ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), ICMP: tcpip.ICMPStats{ - V4PacketsSent: tcpip.ICMPv4SentPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + V4: tcpip.ICMPv4Stats{ + PacketsSent: tcpip.ICMPv4SentPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), - }, - V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - V6PacketsSent: tcpip.ICMPv6SentPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + V6: tcpip.ICMPv6Stats{ + PacketsSent: tcpip.ICMPv6SentPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), - }, - V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, }, IP: tcpip.IPStats{ @@ -209,18 +213,6 @@ const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) -// ntohs converts a 16-bit number from network byte order to host byte order. It -// assumes that the host is little endian. -func ntohs(v uint16) uint16 { - return v<<8 | v>>8 -} - -// htons converts a 16-bit number from host byte order to network byte order. It -// assumes that the host is little endian. -func htons(v uint16) uint16 { - return ntohs(v) -} - // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. type commonEndpoint interface { @@ -240,10 +232,6 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and - // transport.Endpoint.SetSockOptBool. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -252,18 +240,20 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and - // transport.Endpoint.GetSockOpt. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) - // LastError implements tcpip.Endpoint.LastError. + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + + // LastError implements tcpip.Endpoint.LastError and + // transport.Endpoint.LastError. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions implements tcpip.Endpoint.SocketOptions and + // transport.Endpoint.SocketOptions. SocketOptions() *tcpip.SocketOptions } @@ -332,9 +322,7 @@ type socketOpsCommon struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } dirent := socket.NewDirent(t, netstackDevice) @@ -363,88 +351,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, -// AF_INET6, and AF_PACKET addresses. -// -// AddressAndFamily returns an address and its family. -func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { - // Make sure we have at least 2 bytes for the address family. - if len(addr) < 2 { - return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument - } - - // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { - case linux.AF_UNIX: - path := addr[2:] - if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - // Drop the terminating NUL (if one exists) and everything after - // it for filesystem (non-abstract) addresses. - if len(path) > 0 && path[0] != 0 { - if n := bytes.IndexByte(path[1:], 0); n >= 0 { - path = path[:n+1] - } - } - return tcpip.FullAddress{ - Addr: tcpip.Address(path), - }, family, nil - - case linux.AF_INET: - var a linux.SockAddrInet - if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - return out, family, nil - - case linux.AF_INET6: - var a linux.SockAddrInet6 - if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - if isLinkLocal(out.Addr) { - out.NIC = tcpip.NICID(a.Scope_id) - } - return out, family, nil - - case linux.AF_PACKET: - var a linux.SockAddrLink - if len(addr) < sockAddrLinkSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) - if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/173): Return protocol too. - return tcpip.FullAddress{ - NIC: tcpip.NICID(a.InterfaceIndex), - Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), - }, family, nil - - case linux.AF_UNSPEC: - return tcpip.FullAddress{}, family, nil - - default: - return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported - } -} - func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -721,11 +627,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -749,7 +651,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -830,7 +732,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } } else { var err *syserr.Error - addr, family, err = AddressAndFamily(sockaddr) + addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -921,7 +823,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1005,7 +907,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptSocket(t, s, ep, family, skType, name, outLen) case linux.SOL_TCP: - return getSockOptTCP(t, ep, name, outLen) + return getSockOptTCP(t, s, ep, name, outLen) case linux.SOL_IPV6: return getSockOptIPv6(t, s, ep, name, outPtr, outLen) @@ -1068,13 +970,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.PasscredOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetPassCred())) + return &v, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -1115,25 +1012,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReuseAddress())) + return &v, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReusePortOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReusePort())) + return &v, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption @@ -1174,13 +1062,8 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetKeepAlive())) + return &v, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { @@ -1235,21 +1118,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetNoChecksum())) + return &v, nil case linux.SO_ACCEPTCONN: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + // This option is only viable for TCP endpoints. + var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen } vP := primitive.Int32(boolToInt32(v)) return &vP, nil @@ -1261,46 +1141,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(!v)) - return &vP, nil + v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption())) + return &v, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.CorkOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption())) + return &v, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.QuickAckOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck())) + return &v, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1474,19 +1344,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1518,13 +1393,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1536,7 +1406,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: @@ -1545,7 +1415,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1565,7 +1435,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1585,7 +1455,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1607,6 +1477,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1649,7 +1524,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) return &a.(*linux.SockAddrInet).Addr, nil @@ -1658,13 +1533,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1688,26 +1558,24 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil + + case linux.IP_HDRINCL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -1719,7 +1587,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil case linux.IPT_SO_GET_INFO: @@ -1826,7 +1694,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptSocket(t, s, ep, name, optVal) case linux.SOL_TCP: - return setSockOptTCP(t, ep, name, optVal) + return setSockOptTCP(t, s, ep, name, optVal) case linux.SOL_IPV6: return setSockOptIPv6(t, s, ep, name, optVal) @@ -1876,7 +1744,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) + ep.SocketOptions().SetReuseAddress(v != 0) + return nil case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1884,7 +1753,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) + ep.SocketOptions().SetReusePort(v != 0) + return nil case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) @@ -1923,7 +1793,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) + ep.SocketOptions().SetPassCred(v != 0) + return nil case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1931,7 +1802,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) + ep.SocketOptions().SetKeepAlive(v != 0) + return nil case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1979,7 +1851,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + ep.SocketOptions().SetNoChecksum(v != 0) + return nil case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { @@ -2011,7 +1884,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. -func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if len(optVal) < sizeOfInt32 { @@ -2019,7 +1897,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) + ep.SocketOptions().SetDelayOption(v == 0) + return nil case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -2027,7 +1906,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) + ep.SocketOptions().SetCorkOption(v != 0) + return nil case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -2035,7 +1915,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) + ep.SocketOptions().SetQuickAck(v != 0) + return nil case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -2147,14 +2028,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, @@ -2193,7 +2091,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2201,7 +2100,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2276,6 +2175,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2328,7 +2232,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), - InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), + InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]), })) case linux.IP_MULTICAST_LOOP: @@ -2337,7 +2241,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2373,7 +2278,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2383,7 +2289,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2393,7 +2300,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2535,7 +2443,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -2582,72 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { } } -// isLinkLocal determines if the given IPv6 address is link-local. This is the -// case when it has the fe80::/10 prefix. This check is used to determine when -// the NICID is relevant for a given IPv6 address. -func isLinkLocal(addr tcpip.Address) bool { - return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 -} - -// ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { - switch family { - case linux.AF_UNIX: - var out linux.SockAddrUnix - out.Family = linux.AF_UNIX - l := len([]byte(addr.Addr)) - for i := 0; i < l; i++ { - out.Path[i] = int8(addr.Addr[i]) - } - - // Linux returns the used length of the address struct (including the - // null terminator) for filesystem paths. The Family field is 2 bytes. - // It is sometimes allowed to exclude the null terminator if the - // address length is the max. Abstract and empty paths always return - // the full exact length. - if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return &out, uint32(2 + l) - } - return &out, uint32(3 + l) - - case linux.AF_INET: - var out linux.SockAddrInet - copy(out.Addr[:], addr.Addr) - out.Family = linux.AF_INET - out.Port = htons(addr.Port) - return &out, uint32(sockAddrInetSize) - - case linux.AF_INET6: - var out linux.SockAddrInet6 - if len(addr.Addr) == header.IPv4AddressSize { - // Copy address in v4-mapped format. - copy(out.Addr[12:], addr.Addr) - out.Addr[10] = 0xff - out.Addr[11] = 0xff - } else { - copy(out.Addr[:], addr.Addr) - } - out.Family = linux.AF_INET6 - out.Port = htons(addr.Port) - if isLinkLocal(addr.Addr) { - out.Scope_id = uint32(addr.NIC) - } - return &out, uint32(sockAddrInet6Size) - - case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. - var out linux.SockAddrLink - out.Family = linux.AF_PACKET - out.InterfaceIndex = int32(addr.NIC) - out.HardwareAddrLen = header.EthernetAddressSize - copy(out.HardwareAddr[:], addr.Addr) - return &out, uint32(sockAddrLinkSize) - - default: - return nil, 0 - } -} - // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -2656,7 +2497,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2668,7 +2509,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2686,7 +2527,7 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ // Always do at least one fetchReadView, even if the number of bytes to // read is 0. err = s.fetchReadView() - if err != nil { + if err != nil || len(s.readView) == 0 { break } if dst.NumBytes() == 0 { @@ -2709,15 +2550,20 @@ func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequ } copied += n s.readView.TrimFront(n) - if len(s.readView) == 0 { - atomic.StoreUint32(&s.readViewHasData, 0) - } dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) break } + // If we are done reading requested data then stop. + if dst.NumBytes() == 0 { + break + } + } + + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) } // If we managed to copy something, we must deliver it. @@ -2812,10 +2658,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { - addr, addrLen = ConvertAddress(s.family, s.sender) + addr, addrLen = socket.ConvertAddress(s.family, s.sender) switch v := addr.(type) { case *linux.SockAddrLink: - v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } @@ -2980,7 +2826,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, family, err := AddressAndFamily(to) + addrBuf, family, err := socket.AddressAndFamily(to) if err != nil { return 0, err } @@ -3399,6 +3245,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3408,7 +3266,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3437,7 +3295,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3447,7 +3305,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index b0d9e4d9e..b756bfca0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } mnt := t.Kernel().SocketMount() @@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addrLen uint32 if peerAddr != nil { // Get address of the peer and write it to peer slice. - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index ead3b2b79..c847ff1c7 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 2a01143f6..0af805246 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fa9ac9059..cc0fadeb5 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: - in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats - out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ 0, // Icmp/InMsgs. - Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors. 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. @@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. 0, // Icmp/OutMsgs. - Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. - out.DstUnreachable.Value(), // OutDestUnreachs. - out.TimeExceeded.Value(), // OutTimeExcds. - out.ParamProblem.Value(), // OutParmProbs. - out.SrcQuench.Value(), // OutSrcQuenchs. - out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. - out.EchoReply.Value(), // OutEchoReps. - out.Timestamp.Value(), // OutTimestamps. - out.TimestampReply.Value(), // OutTimestampReps. - out.InfoRequest.Value(), // OutAddrMasks. - out.InfoReply.Value(), // OutAddrMaskReps. + Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. } case *inet.StatSNMPTCP: tcp := Metrics.TCP diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index fd31479e5..9049e8a21 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -18,6 +18,7 @@ package socket import ( + "bytes" "fmt" "sync/atomic" "syscall" @@ -35,6 +36,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/usermem" ) @@ -460,3 +462,176 @@ func UnmarshalSockAddr(family int, data []byte) linux.SockAddr { panic(fmt.Sprintf("Unsupported socket family %v", family)) } } + +var sockAddrLinkSize = (&linux.SockAddrLink{}).SizeBytes() +var sockAddrInetSize = (&linux.SockAddrInet{}).SizeBytes() +var sockAddrInet6Size = (&linux.SockAddrInet6{}).SizeBytes() + +// Ntohs converts a 16-bit number from network byte order to host byte order. It +// assumes that the host is little endian. +func Ntohs(v uint16) uint16 { + return v<<8 | v>>8 +} + +// Htons converts a 16-bit number from host byte order to network byte order. It +// assumes that the host is little endian. +func Htons(v uint16) uint16 { + return Ntohs(v) +} + +// isLinkLocal determines if the given IPv6 address is link-local. This is the +// case when it has the fe80::/10 prefix. This check is used to determine when +// the NICID is relevant for a given IPv6 address. +func isLinkLocal(addr tcpip.Address) bool { + return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 +} + +// ConvertAddress converts the given address to a native format. +func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { + switch family { + case linux.AF_UNIX: + var out linux.SockAddrUnix + out.Family = linux.AF_UNIX + l := len([]byte(addr.Addr)) + for i := 0; i < l; i++ { + out.Path[i] = int8(addr.Addr[i]) + } + + // Linux returns the used length of the address struct (including the + // null terminator) for filesystem paths. The Family field is 2 bytes. + // It is sometimes allowed to exclude the null terminator if the + // address length is the max. Abstract and empty paths always return + // the full exact length. + if l == 0 || out.Path[0] == 0 || l == len(out.Path) { + return &out, uint32(2 + l) + } + return &out, uint32(3 + l) + + case linux.AF_INET: + var out linux.SockAddrInet + copy(out.Addr[:], addr.Addr) + out.Family = linux.AF_INET + out.Port = Htons(addr.Port) + return &out, uint32(sockAddrInetSize) + + case linux.AF_INET6: + var out linux.SockAddrInet6 + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. + copy(out.Addr[12:], addr.Addr) + out.Addr[10] = 0xff + out.Addr[11] = 0xff + } else { + copy(out.Addr[:], addr.Addr) + } + out.Family = linux.AF_INET6 + out.Port = Htons(addr.Port) + if isLinkLocal(addr.Addr) { + out.Scope_id = uint32(addr.NIC) + } + return &out, uint32(sockAddrInet6Size) + + case linux.AF_PACKET: + // TODO(gvisor.dev/issue/173): Return protocol too. + var out linux.SockAddrLink + out.Family = linux.AF_PACKET + out.InterfaceIndex = int32(addr.NIC) + out.HardwareAddrLen = header.EthernetAddressSize + copy(out.HardwareAddr[:], addr.Addr) + return &out, uint32(sockAddrLinkSize) + + default: + return nil, 0 + } +} + +// BytesToIPAddress converts an IPv4 or IPv6 address from the user to the +// netstack representation taking any addresses into account. +func BytesToIPAddress(addr []byte) tcpip.Address { + if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) { + return "" + } + return tcpip.Address(addr) +} + +// AddressAndFamily reads an sockaddr struct from the given address and +// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, +// AF_INET6, and AF_PACKET addresses. +// +// AddressAndFamily returns an address and its family. +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { + // Make sure we have at least 2 bytes for the address family. + if len(addr) < 2 { + return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument + } + + // Get the rest of the fields based on the address family. + switch family := usermem.ByteOrder.Uint16(addr); family { + case linux.AF_UNIX: + path := addr[2:] + if len(path) > linux.UnixPathMax { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + // Drop the terminating NUL (if one exists) and everything after + // it for filesystem (non-abstract) addresses. + if len(path) > 0 && path[0] != 0 { + if n := bytes.IndexByte(path[1:], 0); n >= 0 { + path = path[:n+1] + } + } + return tcpip.FullAddress{ + Addr: tcpip.Address(path), + }, family, nil + + case linux.AF_INET: + var a linux.SockAddrInet + if len(addr) < sockAddrInetSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + return out, family, nil + + case linux.AF_INET6: + var a linux.SockAddrInet6 + if len(addr) < sockAddrInet6Size { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) + + out := tcpip.FullAddress{ + Addr: BytesToIPAddress(a.Addr[:]), + Port: Ntohs(a.Port), + } + if isLinkLocal(out.Addr) { + out.NIC = tcpip.NICID(a.Scope_id) + } + return out, family, nil + + case linux.AF_PACKET: + var a linux.SockAddrLink + if len(addr) < sockAddrLinkSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) + if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { + return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/173): Return protocol too. + return tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + }, family, nil + + case linux.AF_UNSPEC: + return tcpip.FullAddress{}, family, nil + + default: + return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported + } +} diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 6d9e502bd..9f7aca305 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -118,28 +118,24 @@ var ( // NewConnectioned creates a new unbound connectionedEndpoint. func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { - return &connectionedEndpoint{ + return newConnectioned(ctx, stype, uid) +} + +func newConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) *connectionedEndpoint { + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { - a := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } - b := &connectionedEndpoint{ - baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, - id: uid.UniqueID(), - idGenerator: uid, - stype: stype, - } + a := newConnectioned(ctx, stype, uid) + b := newConnectioned(ctx, stype, uid) q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q1.InitRefs() @@ -171,12 +167,14 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { - return &connectionedEndpoint{ + ep := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), idGenerator: uid, stype: stype, } + ep.ops.InitHandler(ep) + return ep } // ID implements ConnectingEndpoint.ID. @@ -298,6 +296,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn idGenerator: e.idGenerator, stype: e.stype, } + ne.ops.InitHandler(ne) readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit} readQueue.InitRefs() diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 1406971bc..0813ad87d 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -44,6 +44,7 @@ func NewConnectionless(ctx context.Context) Endpoint { q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit} q.InitRefs() ep.receiver = &queueReceiver{readQueue: &q} + ep.ops.InitHandler(ep) return ep } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 18a50e9f8..0247e93fa 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,8 +16,6 @@ package transport import ( - "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" @@ -180,10 +178,6 @@ type Endpoint interface { // SetSockOpt sets a socket option. SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool sets a socket option for simple cases when a value has - // the int type. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt sets a socket option for simple cases when a value has // the int type. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -191,10 +185,6 @@ type Endpoint interface { // GetSockOpt gets a socket option. GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool gets a socket option for simple cases when a return - // value has the int type. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) @@ -203,10 +193,11 @@ type Endpoint interface { // procfs. State() uint32 - // LastError implements tcpip.Endpoint.LastError. + // LastError clears and returns the last error reported by the endpoint. LastError() *tcpip.Error - // SocketOptions implements tcpip.Endpoint.SocketOptions. + // SocketOptions returns the structure which contains all the socket + // level options. SocketOptions() *tcpip.SocketOptions } @@ -739,10 +730,7 @@ func (e *connectedEndpoint) CloseUnread() { // +stateify savable type baseEndpoint struct { *waiter.Queue - - // passcred specifies whether SCM_CREDENTIALS socket control messages are - // enabled on this endpoint. Must be accessed atomically. - passcred int32 + tcpip.DefaultSocketOptionsHandler // Mutex protects the below fields. sync.Mutex `state:"nosave"` @@ -761,6 +749,7 @@ type baseEndpoint struct { // linger is used for SO_LINGER socket option. linger tcpip.LingerOption + // ops is used to get socket level options. ops tcpip.SocketOptions } @@ -786,7 +775,7 @@ func (e *baseEndpoint) EventUnregister(we *waiter.Entry) { // Passcred implements Credentialer.Passcred. func (e *baseEndpoint) Passcred() bool { - return atomic.LoadInt32(&e.passcred) != 0 + return e.SocketOptions().GetPassCred() } // ConnectedPasscred implements Credentialer.ConnectedPasscred. @@ -796,14 +785,6 @@ func (e *baseEndpoint) ConnectedPasscred() bool { return e.connected != nil && e.connected.Passcred() } -func (e *baseEndpoint) setPasscred(pc bool) { - if pc { - atomic.StoreInt32(&e.passcred, 1) - } else { - atomic.StoreInt32(&e.passcred, 0) - } -} - // Connected implements ConnectingEndpoint.Connected. func (e *baseEndpoint) Connected() bool { return e.receiver != nil && e.connected != nil @@ -868,17 +849,6 @@ func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error { return nil } -func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { - switch opt { - case tcpip.PasscredOption: - e.setPasscred(v) - case tcpip.ReuseAddressOption: - default: - log.Warningf("Unsupported socket option: %d", opt) - } - return nil -} - func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { switch opt { case tcpip.SendBufferSizeOption: @@ -889,20 +859,6 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { return nil } -func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - switch opt { - case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: - return false, nil - - case tcpip.PasscredOption: - return e.Passcred(), nil - - default: - log.Warningf("Unsupported socket option: %d", opt) - return false, tcpip.ErrUnknownProtocolOption - } -} - func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 3e520d2ee..c59297c80 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -115,9 +115,6 @@ type socketOpsCommon struct { // bound, they cannot be modified. abstractName string abstractNamespace *kernel.AbstractSocketNamespace - - // ops is used to get socket level options. - ops tcpip.SocketOptions } func (s *socketOpsCommon) isPacket() bool { @@ -139,7 +136,7 @@ func (s *socketOpsCommon) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, family, err := netstack.AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { if err == syserr.ErrAddressFamilyNotSupported { err = syserr.ErrInvalidArgument @@ -172,7 +169,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -184,7 +181,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) + a, l := socket.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -258,7 +255,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -650,7 +647,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -685,7 +682,7 @@ func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = socket.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index eaf0b0d26..27f705bb2 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -172,7 +172,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr) + addr, addrLen = socket.ConvertAddress(linux.AF_UNIX, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ |