diff options
Diffstat (limited to 'pkg/sentry/socket/netstack')
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 519 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack_vfs2.go | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/provider.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/provider_vfs2.go | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go | 30 |
5 files changed, 215 insertions, 344 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index e8a0103bf..5e9ab97ad 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -84,69 +84,73 @@ var Metrics = tcpip.Stats{ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."), DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."), ICMP: tcpip.ICMPStats{ - V4PacketsSent: tcpip.ICMPv4SentPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + V4: tcpip.ICMPv4Stats{ + PacketsSent: tcpip.ICMPv4SentPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), }, - Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."), - }, - V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ - ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ - Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), - SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), - Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), - Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), - TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), - InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), - InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{ + ICMPv4PacketStats: tcpip.ICMPv4PacketStats{ + Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."), + SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."), + Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."), + Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."), + TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."), + InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."), + InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."), }, - V6PacketsSent: tcpip.ICMPv6SentPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + V6: tcpip.ICMPv6Stats{ + PacketsSent: tcpip.ICMPv6SentPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."), + }, + Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), }, - Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."), - }, - V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ - ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ - EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), - EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), - DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), - PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), - TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), - ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), - RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), - RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), - NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), - NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), - RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{ + ICMPv6PacketStats: tcpip.ICMPv6PacketStats{ + EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."), + EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."), + DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."), + PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."), + TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."), + ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."), + RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."), + RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."), + NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."), + NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."), + RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."), + }, + Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, - Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."), }, }, IP: tcpip.IPStats{ @@ -209,18 +213,6 @@ const sizeOfInt32 int = 4 var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) -// ntohs converts a 16-bit number from network byte order to host byte order. It -// assumes that the host is little endian. -func ntohs(v uint16) uint16 { - return v<<8 | v>>8 -} - -// htons converts a 16-bit number from host byte order to network byte order. It -// assumes that the host is little endian. -func htons(v uint16) uint16 { - return ntohs(v) -} - // commonEndpoint represents the intersection of a tcpip.Endpoint and a // transport.Endpoint. type commonEndpoint interface { @@ -240,10 +232,6 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error - // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and - // transport.Endpoint.SetSockOptBool. - SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error - // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error @@ -252,14 +240,14 @@ type commonEndpoint interface { // transport.Endpoint.GetSockOpt. GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error - // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and - // transport.Endpoint.GetSockOpt. - GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) - // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 + // LastError implements tcpip.Endpoint.LastError and // transport.Endpoint.LastError. LastError() *tcpip.Error @@ -334,9 +322,7 @@ type socketOpsCommon struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } dirent := socket.NewDirent(t, netstackDevice) @@ -365,88 +351,6 @@ func bytesToIPAddress(addr []byte) tcpip.Address { return tcpip.Address(addr) } -// AddressAndFamily reads an sockaddr struct from the given address and -// converts it to the FullAddress format. It supports AF_UNIX, AF_INET, -// AF_INET6, and AF_PACKET addresses. -// -// AddressAndFamily returns an address and its family. -func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { - // Make sure we have at least 2 bytes for the address family. - if len(addr) < 2 { - return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument - } - - // Get the rest of the fields based on the address family. - switch family := usermem.ByteOrder.Uint16(addr); family { - case linux.AF_UNIX: - path := addr[2:] - if len(path) > linux.UnixPathMax { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - // Drop the terminating NUL (if one exists) and everything after - // it for filesystem (non-abstract) addresses. - if len(path) > 0 && path[0] != 0 { - if n := bytes.IndexByte(path[1:], 0); n >= 0 { - path = path[:n+1] - } - } - return tcpip.FullAddress{ - Addr: tcpip.Address(path), - }, family, nil - - case linux.AF_INET: - var a linux.SockAddrInet - if len(addr) < sockAddrInetSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - return out, family, nil - - case linux.AF_INET6: - var a linux.SockAddrInet6 - if len(addr) < sockAddrInet6Size { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a) - - out := tcpip.FullAddress{ - Addr: bytesToIPAddress(a.Addr[:]), - Port: ntohs(a.Port), - } - if isLinkLocal(out.Addr) { - out.NIC = tcpip.NICID(a.Scope_id) - } - return out, family, nil - - case linux.AF_PACKET: - var a linux.SockAddrLink - if len(addr) < sockAddrLinkSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a) - if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize { - return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument - } - - // TODO(gvisor.dev/issue/173): Return protocol too. - return tcpip.FullAddress{ - NIC: tcpip.NICID(a.InterfaceIndex), - Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), - }, family, nil - - case linux.AF_UNSPEC: - return tcpip.FullAddress{}, family, nil - - default: - return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported - } -} - func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } @@ -723,11 +627,7 @@ func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { return nil } if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { - v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return syserr.TranslateNetstackError(err) - } - if !v { + if !s.Endpoint.SocketOptions().GetV6Only() { return nil } } @@ -751,7 +651,7 @@ func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(sockaddr) + addr, family, err := socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -832,7 +732,7 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } } else { var err *syserr.Error - addr, family, err = AddressAndFamily(sockaddr) + addr, family, err = socket.AddressAndFamily(sockaddr) if err != nil { return err } @@ -923,7 +823,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, var addr linux.SockAddr var addrLen uint32 if peerAddr != nil { - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{ @@ -1007,7 +907,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptSocket(t, s, ep, family, skType, name, outLen) case linux.SOL_TCP: - return getSockOptTCP(t, ep, name, outLen) + return getSockOptTCP(t, s, ep, name, outLen) case linux.SOL_IPV6: return getSockOptIPv6(t, s, ep, name, outPtr, outLen) @@ -1226,8 +1126,13 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam return nil, syserr.ErrInvalidArgument } - v := primitive.Int32(boolToInt32(ep.SocketOptions().GetAcceptConn())) - return &v, nil + // This option is only viable for TCP endpoints. + var v bool + if _, skType, skProto := s.Type(); isTCPSocket(skType, skProto) { + v = tcp.EndpointState(ep.State()) == tcp.StateListen + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1236,46 +1141,36 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.DelayOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(!v)) - return &vP, nil + v := primitive.Int32(boolToInt32(!ep.SocketOptions().GetDelayOption())) + return &v, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.CorkOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetCorkOption())) + return &v, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.QuickAckOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetQuickAck())) + return &v, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { @@ -1449,19 +1344,24 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + + family, skType, _ := s.Type() + if family != linux.AF_INET6 { + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetV6Only())) + return &v, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1493,13 +1393,8 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTClass())) + return &v, nil case linux.IP6T_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet6{})) { @@ -1511,7 +1406,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET6, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet6), nil case linux.IP6T_SO_GET_INFO: @@ -1520,7 +1415,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1540,7 +1435,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return nil, syserr.ErrInvalidArgument } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1560,7 +1455,7 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return nil, syserr.ErrProtocolNotAvailable } @@ -1582,6 +1477,11 @@ func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name // getSockOptIP implements GetSockOpt when level is SOL_IP. func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return nil, syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { @@ -1624,7 +1524,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) return &a.(*linux.SockAddrInet).Addr, nil @@ -1633,13 +1533,8 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetMulticastLoop())) + return &v, nil case linux.IP_TOS: // Length handling for parity with Linux. @@ -1663,26 +1558,24 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) - } - - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceiveTOS())) + return &v, nil case linux.IP_PKTINFO: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) - if err != nil { - return nil, syserr.TranslateNetstackError(err) + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetReceivePacketInfo())) + return &v, nil + + case linux.IP_HDRINCL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - vP := primitive.Int32(boolToInt32(v)) - return &vP, nil + v := primitive.Int32(boolToInt32(ep.SocketOptions().GetHeaderIncluded())) + return &v, nil case linux.SO_ORIGINAL_DST: if outLen < int(binary.Size(linux.SockAddrInet{})) { @@ -1694,7 +1587,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return nil, syserr.TranslateNetstackError(err) } - a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + a, _ := socket.ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) return a.(*linux.SockAddrInet), nil case linux.IPT_SO_GET_INFO: @@ -1801,7 +1694,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptSocket(t, s, ep, name, optVal) case linux.SOL_TCP: - return setSockOptTCP(t, ep, name, optVal) + return setSockOptTCP(t, s, ep, name, optVal) case linux.SOL_IPV6: return setSockOptIPv6(t, s, ep, name, optVal) @@ -1991,7 +1884,12 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } // setSockOptTCP implements SetSockOpt when level is SOL_TCP. -func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptTCP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, skType, skProto := s.Type(); !isTCPSocket(skType, skProto) { + log.Warningf("SOL_TCP options are only supported on TCP sockets: skType, skProto = %v, %d", skType, skProto) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.TCP_NODELAY: if len(optVal) < sizeOfInt32 { @@ -1999,7 +1897,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) + ep.SocketOptions().SetDelayOption(v == 0) + return nil case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -2007,7 +1906,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) + ep.SocketOptions().SetCorkOption(v != 0) + return nil case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -2015,7 +1915,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) + ep.SocketOptions().SetQuickAck(v != 0) + return nil case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -2127,14 +2028,31 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IPV6 options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + + family, skType, skProto := s.Type() + if family != linux.AF_INET6 { + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument } + if isTCPSocket(skType, skProto) && tcp.EndpointState(ep.State()) != tcp.StateInitial { + return syserr.ErrInvalidEndpointState + } else if isUDPSocket(skType, skProto) && udp.EndpointState(ep.State()) != udp.StateInitial { + return syserr.ErrInvalidEndpointState + } + v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) + ep.SocketOptions().SetV6Only(v != 0) + return nil case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, @@ -2173,7 +2091,8 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + ep.SocketOptions().SetReceiveTClass(v != 0) + return nil case linux.IP6T_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIP6TReplace { @@ -2181,7 +2100,7 @@ func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name } // Only valid for raw IPv6 sockets. - if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + if skType != linux.SOCK_RAW { return syserr.ErrProtocolNotAvailable } @@ -2256,6 +2175,11 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) { // setSockOptIP implements SetSockOpt when level is SOL_IP. func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { + if _, ok := ep.(tcpip.Endpoint); !ok { + log.Warningf("SOL_IP options not supported on endpoints other than tcpip.Endpoint: option = %d", name) + return syserr.ErrUnknownProtocolOption + } + switch name { case linux.IP_MULTICAST_TTL: v, err := parseIntOrChar(optVal) @@ -2308,7 +2232,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{ NIC: tcpip.NICID(req.InterfaceIndex), - InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]), + InterfaceAddr: socket.BytesToIPAddress(req.InterfaceAddr[:]), })) case linux.IP_MULTICAST_LOOP: @@ -2317,7 +2241,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) + ep.SocketOptions().SetMulticastLoop(v != 0) + return nil case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -2353,7 +2278,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + ep.SocketOptions().SetReceiveTOS(v != 0) + return nil case linux.IP_PKTINFO: if len(optVal) == 0 { @@ -2363,7 +2289,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + ep.SocketOptions().SetReceivePacketInfo(v != 0) + return nil case linux.IP_HDRINCL: if len(optVal) == 0 { @@ -2373,7 +2300,8 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) + ep.SocketOptions().SetHeaderIncluded(v != 0) + return nil case linux.IPT_SO_SET_REPLACE: if len(optVal) < linux.SizeOfIPTReplace { @@ -2515,7 +2443,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { switch name { case linux.IP_TOS, linux.IP_TTL, - linux.IP_HDRINCL, linux.IP_OPTIONS, linux.IP_ROUTER_ALERT, linux.IP_RECVOPTS, @@ -2562,72 +2489,6 @@ func emitUnimplementedEventIP(t *kernel.Task, name int) { } } -// isLinkLocal determines if the given IPv6 address is link-local. This is the -// case when it has the fe80::/10 prefix. This check is used to determine when -// the NICID is relevant for a given IPv6 address. -func isLinkLocal(addr tcpip.Address) bool { - return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80 -} - -// ConvertAddress converts the given address to a native format. -func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) { - switch family { - case linux.AF_UNIX: - var out linux.SockAddrUnix - out.Family = linux.AF_UNIX - l := len([]byte(addr.Addr)) - for i := 0; i < l; i++ { - out.Path[i] = int8(addr.Addr[i]) - } - - // Linux returns the used length of the address struct (including the - // null terminator) for filesystem paths. The Family field is 2 bytes. - // It is sometimes allowed to exclude the null terminator if the - // address length is the max. Abstract and empty paths always return - // the full exact length. - if l == 0 || out.Path[0] == 0 || l == len(out.Path) { - return &out, uint32(2 + l) - } - return &out, uint32(3 + l) - - case linux.AF_INET: - var out linux.SockAddrInet - copy(out.Addr[:], addr.Addr) - out.Family = linux.AF_INET - out.Port = htons(addr.Port) - return &out, uint32(sockAddrInetSize) - - case linux.AF_INET6: - var out linux.SockAddrInet6 - if len(addr.Addr) == header.IPv4AddressSize { - // Copy address in v4-mapped format. - copy(out.Addr[12:], addr.Addr) - out.Addr[10] = 0xff - out.Addr[11] = 0xff - } else { - copy(out.Addr[:], addr.Addr) - } - out.Family = linux.AF_INET6 - out.Port = htons(addr.Port) - if isLinkLocal(addr.Addr) { - out.Scope_id = uint32(addr.NIC) - } - return &out, uint32(sockAddrInet6Size) - - case linux.AF_PACKET: - // TODO(gvisor.dev/issue/173): Return protocol too. - var out linux.SockAddrLink - out.Family = linux.AF_PACKET - out.InterfaceIndex = int32(addr.NIC) - out.HardwareAddrLen = header.EthernetAddressSize - copy(out.HardwareAddr[:], addr.Addr) - return &out, uint32(sockAddrLinkSize) - - default: - return nil, 0 - } -} - // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { @@ -2636,7 +2497,7 @@ func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2648,7 +2509,7 @@ func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, * return nil, 0, syserr.TranslateNetstackError(err) } - a, l := ConvertAddress(s.family, addr) + a, l := socket.ConvertAddress(s.family, addr) return a, l, nil } @@ -2797,10 +2658,10 @@ func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSeq var addr linux.SockAddr var addrLen uint32 if isPacket && senderRequested { - addr, addrLen = ConvertAddress(s.family, s.sender) + addr, addrLen = socket.ConvertAddress(s.family, s.sender) switch v := addr.(type) { case *linux.SockAddrLink: - v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.Protocol = socket.Htons(uint16(s.linkPacketInfo.Protocol)) v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) } } @@ -2965,7 +2826,7 @@ func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []b var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, family, err := AddressAndFamily(to) + addrBuf, family, err := socket.AddressAndFamily(to) if err != nil { return 0, err } @@ -3384,6 +3245,18 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { return rv } +func isTCPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_STREAM && (skProto == 0 || skProto == syscall.IPPROTO_TCP) +} + +func isUDPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == 0 || skProto == syscall.IPPROTO_UDP) +} + +func isICMPSocket(skType linux.SockType, skProto int) bool { + return skType == linux.SOCK_DGRAM && (skProto == syscall.IPPROTO_ICMP || skProto == syscall.IPPROTO_ICMPV6) +} + // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. func (s *socketOpsCommon) State() uint32 { @@ -3393,7 +3266,7 @@ func (s *socketOpsCommon) State() uint32 { } switch { - case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: + case isTCPSocket(s.skType, s.protocol): // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -3422,7 +3295,7 @@ func (s *socketOpsCommon) State() uint32 { // Internal or unknown state. return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + case isUDPSocket(s.skType, s.protocol): // UDP socket. switch udp.EndpointState(s.Endpoint.State()) { case udp.StateInitial, udp.StateBound, udp.StateClosed: @@ -3432,7 +3305,7 @@ func (s *socketOpsCommon) State() uint32 { default: return 0 } - case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + case isICMPSocket(s.skType, s.protocol): // TODO(b/112063468): Export states for ICMP sockets. case s.skType == linux.SOCK_RAW: // TODO(b/112063468): Export states for raw sockets. diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index b0d9e4d9e..b756bfca0 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -51,9 +51,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // NewVFS2 creates a new endpoint socket. func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { - return nil, syserr.TranslateNetstackError(err) - } + endpoint.SocketOptions().SetDelayOption(true) } mnt := t.Kernel().SocketMount() @@ -191,7 +189,7 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block var addrLen uint32 if peerAddr != nil { // Get address of the peer and write it to peer slice. - addr, addrLen = ConvertAddress(s.family, *peerAddr) + addr, addrLen = socket.ConvertAddress(s.family, *peerAddr) } fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index ead3b2b79..c847ff1c7 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -158,7 +158,7 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go index 2a01143f6..0af805246 100644 --- a/pkg/sentry/socket/netstack/provider_vfs2.go +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -102,7 +102,7 @@ func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, prot // protocol is passed in network byte order, but netstack wants it in // host order. - netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + netProto := tcpip.NetworkProtocolNumber(socket.Ntohs(uint16(protocol))) wq := &waiter.Queue{} ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index fa9ac9059..cc0fadeb5 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -324,12 +324,12 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: - in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats - out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + in := Metrics.ICMP.V4.PacketsReceived.ICMPv4PacketStats + out := Metrics.ICMP.V4.PacketsSent.ICMPv4PacketStats // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ 0, // Icmp/InMsgs. - Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. + Metrics.ICMP.V4.PacketsSent.Dropped.Value(), // InErrors. 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. @@ -343,18 +343,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. 0, // Icmp/OutMsgs. - Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. - out.DstUnreachable.Value(), // OutDestUnreachs. - out.TimeExceeded.Value(), // OutTimeExcds. - out.ParamProblem.Value(), // OutParmProbs. - out.SrcQuench.Value(), // OutSrcQuenchs. - out.Redirect.Value(), // OutRedirects. - out.Echo.Value(), // OutEchos. - out.EchoReply.Value(), // OutEchoReps. - out.Timestamp.Value(), // OutTimestamps. - out.TimestampReply.Value(), // OutTimestampReps. - out.InfoRequest.Value(), // OutAddrMasks. - out.InfoReply.Value(), // OutAddrMaskReps. + Metrics.ICMP.V4.PacketsReceived.Invalid.Value(), // OutErrors. + out.DstUnreachable.Value(), // OutDestUnreachs. + out.TimeExceeded.Value(), // OutTimeExcds. + out.ParamProblem.Value(), // OutParmProbs. + out.SrcQuench.Value(), // OutSrcQuenchs. + out.Redirect.Value(), // OutRedirects. + out.Echo.Value(), // OutEchos. + out.EchoReply.Value(), // OutEchoReps. + out.Timestamp.Value(), // OutTimestamps. + out.TimestampReply.Value(), // OutTimestampReps. + out.InfoRequest.Value(), // OutAddrMasks. + out.InfoReply.Value(), // OutAddrMaskReps. } case *inet.StatSNMPTCP: tcp := Metrics.TCP |