diff options
author | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
---|---|---|
committer | Ian Lewis <ianmlewis@gmail.com> | 2020-08-17 21:44:31 -0400 |
commit | ac324f646ee3cb7955b0b45a7453aeb9671cbdf1 (patch) | |
tree | 0cbc5018e8807421d701d190dc20525726c7ca76 /pkg/sentry/socket/netstack/netstack.go | |
parent | 352ae1022ce19de28fc72e034cc469872ad79d06 (diff) | |
parent | 6d0c5803d557d453f15ac6f683697eeb46dab680 (diff) |
Merge branch 'master' into ip-forwarding
- Merges aleksej-paschenko's with HEAD
- Adds vfs2 support for ip_forward
Diffstat (limited to 'pkg/sentry/socket/netstack/netstack.go')
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go | 1064 |
1 files changed, 785 insertions, 279 deletions
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 27c6692c4..e4846bc0b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,29 +26,32 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" - "sync" + "sync/atomic" "syscall" "time" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -57,12 +60,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { var cm tcpip.StatCounter - metric.MustRegisterCustomUint64Metric(name, false /* sync */, description, cm.Value) + metric.MustRegisterCustomUint64Metric(name, true /* cumulative */, false /* sync */, description, cm.Value) + return &cm +} + +func mustCreateGauge(name, description string) *tcpip.StatCounter { + var cm tcpip.StatCounter + metric.MustRegisterCustomUint64Metric(name, false /* cumulative */, false /* sync */, description, cm.Value) return &cm } @@ -138,19 +150,23 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), - MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), - MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), - CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + CurrentEstablished: mustCreateGauge("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."), + CurrentConnected: mustCreateGauge("/netstack/tcp/current_open", "Number of connections that are in connected state."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), @@ -178,6 +194,8 @@ var Metrics = tcpip.Stats{ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), + ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -220,19 +238,29 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and + // transport.Endpoint.SetSockOptBool. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and + // transport.Endpoint.GetSockOpt. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) } +// LINT.IfChange + // SocketOperations encapsulates all the state needed to represent a network stack // endpoint in the kernel context. // @@ -244,6 +272,14 @@ type SocketOperations struct { fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout *waiter.Queue @@ -252,14 +288,21 @@ type SocketOperations struct { skType linux.SockType protocol int + // readViewHasData is 1 iff readView has data to be read, 0 otherwise. + // Must be accessed using atomic operations. It must only be written + // with readMu held but can be read without holding readMu. The latter + // is required to avoid deadlocks in epoll Readiness checks. + readViewHasData uint32 + // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` // readView contains the remaining payload from the last packet. readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -281,19 +324,21 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { return nil, syserr.TranslateNetstackError(err) } } dirent := socket.NewDirent(t, netstackDevice) - defer dirent.DecRef() + defer dirent.DecRef(t) return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ - Queue: queue, - family: family, - Endpoint: endpoint, - skType: skType, - protocol: protocol, + socketOpsCommon: socketOpsCommon{ + Queue: queue, + family: family, + Endpoint: endpoint, + skType: skType, + protocol: protocol, + }, }), nil } @@ -314,22 +359,15 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // converts it to the FullAddress format. It supports AF_UNIX, AF_INET, // AF_INET6, and AF_PACKET addresses. // -// strict indicates whether addresses with the AF_UNSPEC family are accepted of not. -// // AddressAndFamily returns an address and its family. -func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument } - family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { - return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported - } - // Get the rest of the fields based on the address family. - switch family { + switch family := usermem.ByteOrder.Uint16(addr); family { case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { @@ -385,7 +423,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), @@ -399,33 +437,49 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } } -func (s *SocketOperations) isPacketBased() bool { +func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } // fetchReadView updates the readView field of the socket if it's currently // empty. It assumes that the socket is locked. -func (s *SocketOperations) fetchReadView() *syserr.Error { +// +// Precondition: s.readMu must be held. +func (s *socketOpsCommon) fetchReadView() *syserr.Error { if len(s.readView) > 0 { return nil } - s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { + atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms + atomic.StoreUint32(&s.readViewHasData, 1) return nil } // Release implements fs.FileOperations.Release. -func (s *SocketOperations) Release() { +func (s *socketOpsCommon) Release(context.Context) { s.Endpoint.Close() } @@ -520,11 +574,9 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } if resCh != nil { - t := ctx.(*kernel.Task) - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err).ToError() + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) } @@ -593,11 +645,9 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } if resCh != nil { - t := ctx.(*kernel.Task) - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err).ToError() + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ Atomic: true, // See above. }) @@ -612,26 +662,54 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } // Readiness returns a mask of ready events for socket s. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - s.readMu.Lock() - if len(s.readView) > 0 { + if atomic.LoadUint32(&s.readViewHasData) == 1 { r |= waiter.EventIn } - s.readMu.Unlock() } return r } +func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { + if family == uint16(s.family) { + return nil + } + if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { + v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { + return syserr.TranslateNetstackError(err) + } + if !v { + return nil + } + } + return syserr.ErrInvalidArgument +} + +// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the +// receiver's family is AF_INET6. +// +// This is a hack to work around the fact that both IPv4 and IPv6 ANY are +// represented by the empty string. +// +// TODO(gvisor.dev/issue/1556): remove this function. +func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { + if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET { + addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + } + return addr +} + // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */) +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { + addr, family, err := AddressAndFamily(sockaddr) if err != nil { return err } @@ -643,6 +721,12 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } return syserr.TranslateNetstackError(err) } + + if err := s.checkFamily(family, false /* exact */); err != nil { + return err + } + addr = s.mapFamily(addr, family) + // Always return right away in the non-blocking case. if !blocking { return syserr.TranslateNetstackError(s.Endpoint.Connect(addr)) @@ -655,6 +739,14 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo defer s.EventUnregister(&e) if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { + // TCP unlike UDP returns EADDRNOTAVAIL when it can't + // find an available local ephemeral port. + if err == tcpip.ErrNoPortAvailable { + return syserr.ErrAddressNotAvailable + } + } + return syserr.TranslateNetstackError(err) } @@ -670,10 +762,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */) - if err != nil { - return err +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + if len(sockaddr) < 2 { + return syserr.ErrInvalidArgument + } + + family := usermem.ByteOrder.Uint16(sockaddr) + var addr tcpip.FullAddress + + // Bind for AF_PACKET requires only family, protocol and ifindex. + // In function AddressAndFamily, we check the address length which is + // not needed for AF_PACKET bind. + if family == linux.AF_PACKET { + var a linux.SockAddrLink + if len(sockaddr) < sockAddrLinkSize { + return syserr.ErrInvalidArgument + } + binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a) + + if a.Protocol != uint16(s.protocol) { + return syserr.ErrInvalidArgument + } + + addr = tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + } + } else { + var err *syserr.Error + addr, family, err = AddressAndFamily(sockaddr) + if err != nil { + return err + } + + if err = s.checkFamily(family, true /* exact */); err != nil { + return err + } + + addr = s.mapFamily(addr, family) } // Issue the bind request to the endpoint. @@ -682,13 +808,13 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Listen implements the linux syscall listen(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog)) } // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -728,7 +854,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -774,7 +900,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) { // Shutdown implements the linux syscall shutdown(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := ConvertShutdown(how) if err != nil { return err @@ -786,7 +912,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -796,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -824,22 +950,30 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us return nil, syserr.ErrInvalidArgument } - info, err := netfilter.GetInfo(t, s.Endpoint, outPtr) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { return nil, syserr.ErrInvalidArgument } - entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -849,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -874,8 +1008,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return nil, syserr.ErrProtocolNotAvailable } +func boolToInt32(v bool) int32 { + if v { + return 1 + } + return 0 +} + // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -886,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -896,23 +1040,25 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.PasscredOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.PasscredOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -928,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -944,74 +1091,93 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReuseAddressOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReusePortOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReusePortOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - if len(v) == 0 { - return []byte{}, nil + if v == 0 { + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument } - return append([]byte(v), 0), nil + s := t.NetworkContext() + if s == nil { + return nil, syserr.ErrNoDevice + } + nic, ok := s.Interfaces()[int32(v)] + if !ok { + // The NICID no longer indicates a valid interface, probably because that + // interface was removed. + return nil, syserr.ErrUnknownDevice + } + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.BroadcastOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.BroadcastOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.KeepaliveEnabledOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1019,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1027,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1039,7 +1207,20 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil + + case linux.SO_NO_CHECK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1048,58 +1229,58 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptInt(tcpip.DelayOption) + v, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - if v == 0 { - return int32(1), nil - } - return int32(0), nil + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.CorkOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.CorkOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.QuickAckOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.QuickAckOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MaxSegOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MaxSegOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1110,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1122,8 +1303,32 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil + + case linux.TCP_KEEPCNT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(v) + return &vP, nil - return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1136,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1171,8 +1377,59 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + bP := primitive.ByteSlice(b) + return &bP, nil + + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil + + case linux.TCP_DEFER_ACCEPT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPDeferAcceptOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil + + case linux.TCP_SYNCNT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.TCPSynCountOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(v) + return &vP, nil + + case linux.TCP_WINDOW_CLAMP: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.TCPWindowClampOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1180,19 +1437,20 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.V6OnlyOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1200,21 +1458,41 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } - var v tcpip.IPv6TrafficClassOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil + + case linux.IPV6_RECVTCLASS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.SO_ORIGINAL_DST: + // TODO(gvisor.dev/issue/170): ip6tables. + return nil, syserr.ErrInvalidArgument default: emitUnimplementedEventIPv6(t, name) @@ -1223,36 +1501,38 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.TTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.TTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastTTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1266,36 +1546,76 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastLoopOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - if v { - return int32(1), nil - } - return int32(0), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } - var v tcpip.IPv4TOSOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil + } + vP := primitive.Int32(v) + return &vP, nil + + case linux.IP_RECVTOS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - return int32(v), nil + + v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.SO_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet), nil default: emitUnimplementedEventIP(t, name) @@ -1330,12 +1650,32 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa return nil } + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal) + + case linux.IPT_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + } + } + return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } // SetSockOpt can be used to implement the linux syscall setsockopt(2) for // sockets backed by a commonEndpoint. -func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { +func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { switch level { case linux.SOL_SOCKET: return setSockOptSocket(t, s, ep, name, optVal) @@ -1362,7 +1702,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n } // setSockOptSocket implements SetSockOpt when level is SOL_SOCKET. -func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.SO_SNDBUF: if len(optVal) < sizeOfInt32 { @@ -1386,7 +1726,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1394,14 +1734,27 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) if n == -1 { n = len(optVal) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + name := string(optVal[:n]) + if name == "" { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + } + s := t.NetworkContext() + if s == nil { + return syserr.ErrNoDevice + } + for nicID, nic := range s.Interfaces() { + if nic.Name == name { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + } + } + return syserr.ErrUnknownDevice case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { @@ -1409,7 +1762,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1417,7 +1770,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1425,7 +1778,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1466,6 +1819,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + case linux.SO_NO_CHECK: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { return syserr.ErrInvalidArgument @@ -1480,6 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i return nil + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -1497,11 +1863,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o int - if v == 0 { - o = 1 - } - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1509,7 +1871,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -1517,7 +1879,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -1525,7 +1887,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v))) case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { @@ -1549,6 +1911,28 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_KEEPCNT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + if v < 1 || v > linux.MAX_TCP_KEEPCNT { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v))) + + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { @@ -1556,6 +1940,40 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + + case linux.TCP_DEFER_ACCEPT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + + case linux.TCP_SYNCNT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := usermem.ByteOrder.Uint32(optVal) + + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v))) + + case linux.TCP_WINDOW_CLAMP: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := usermem.ByteOrder.Uint32(optVal) + + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1576,13 +1994,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, linux.IPV6_IPSEC_POLICY, linux.IPV6_JOIN_ANYCAST, linux.IPV6_LEAVE_ANYCAST, + // TODO(b/148887420): Add support for IPV6_PKTINFO. linux.IPV6_PKTINFO, linux.IPV6_ROUTER_ALERT, linux.IPV6_XFRM_POLICY, @@ -1606,7 +2025,15 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) if v == -1 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v))) + + case linux.IPV6_RECVTCLASS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) default: emitUnimplementedEventIPv6(t, name) @@ -1683,7 +2110,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if v < 0 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v))) case linux.IP_ADD_MEMBERSHIP: req, err := copyInMulticastRequest(optVal, false /* allowAddr */) @@ -1730,9 +2157,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.MulticastLoopOption(v != 0), - )) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -1751,7 +2176,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } else if v < 1 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v))) case linux.IP_TOS: if len(optVal) == 0 { @@ -1761,7 +2186,34 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v))) + + case linux.IP_RECVTOS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + + case linux.IP_PKTINFO: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, @@ -1769,7 +2221,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, @@ -1778,12 +2229,10 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_PKTINFO, linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, - linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -1811,30 +2260,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { switch name { case linux.TCP_CONGESTION, linux.TCP_CORK, - linux.TCP_DEFER_ACCEPT, linux.TCP_FASTOPEN, linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_KEEPCNT, - linux.TCP_KEEPIDLE, - linux.TCP_KEEPINTVL, - linux.TCP_LINGER2, - linux.TCP_MAXSEG, linux.TCP_QUEUE_SEQ, - linux.TCP_QUICKACK, linux.TCP_REPAIR, linux.TCP_REPAIR_QUEUE, linux.TCP_REPAIR_WINDOW, linux.TCP_SAVED_SYN, linux.TCP_SAVE_SYN, - linux.TCP_SYNCNT, linux.TCP_THIN_DUPACK, linux.TCP_THIN_LINEAR_TIMEOUTS, linux.TCP_TIMESTAMP, - linux.TCP_ULP, - linux.TCP_USER_TIMEOUT, - linux.TCP_WINDOW_CLAMP: + linux.TCP_ULP: t.Kernel().EmitUnimplementedEvent(t) } @@ -1876,7 +2315,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, - linux.IPV6_RECVTCLASS, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, linux.IPV6_TCLASS, @@ -1981,8 +2419,8 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) case linux.AF_INET6: var out linux.SockAddrInet6 - if len(addr.Addr) == 4 { - // Copy address is v4-mapped format. + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. copy(out.Addr[12:], addr.Addr) out.Addr[10] = 0xff out.Addr[11] = 0xff @@ -1997,7 +2435,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) @@ -2012,7 +2450,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2024,7 +2462,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2039,16 +2477,21 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // caller. // // Precondition: s.readMu must be locked. -func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { +func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { var err *syserr.Error var copied int // Copy as many views as possible into the user-provided buffer. - for dst.NumBytes() != 0 { + for { + // Always do at least one fetchReadView, even if the number of bytes to + // read is 0. err = s.fetchReadView() if err != nil { break } + if dst.NumBytes() == 0 { + break + } var n int var e error @@ -2066,6 +2509,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) @@ -2082,7 +2529,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq return 0, err } -func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { +func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { if !s.sockOptInq { return } @@ -2094,10 +2541,27 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -2112,9 +2576,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // caller-supplied buffer. s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) - s.readMu.Unlock() cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) + s.readMu.Unlock() return n, 0, nil, 0, cmsg, err } @@ -2149,6 +2613,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + } } if peek { @@ -2188,6 +2657,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC @@ -2202,15 +2675,26 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } -func (s *SocketOperations) controlMessages() socket.ControlMessages { - return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}} +func (s *socketOpsCommon) controlMessages() socket.ControlMessages { + return socket.ControlMessages{ + IP: tcpip.ControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + }, + } } // updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after // successfully writing packet data out to userspace. // // Precondition: s.readMu must be locked. -func (s *SocketOperations) updateTimestamp() { +func (s *socketOpsCommon) updateTimestamp() { // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. if !s.sockOptTimestamp { s.timestampValid = true @@ -2220,7 +2704,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -2288,7 +2772,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // SendMsg implements the linux syscall sendmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Reject Unix control messages. if !controlMessages.Unix.Empty() { return 0, syserr.ErrInvalidArgument @@ -2296,10 +2780,14 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */) + addrBuf, family, err := AddressAndFamily(to) if err != nil { return 0, err } + if err := s.checkFamily(family, false /* exact */); err != nil { + return 0, err + } + addrBuf = s.mapFamily(addrBuf, family) addr = &addrBuf } @@ -2360,11 +2848,20 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return s.socketOpsCommon.ioctl(ctx, io, args) +} + +func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2372,9 +2869,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } tv := linux.NsecToTimeval(s.timestampNS) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2384,16 +2879,17 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } // Add bytes removed from the endpoint but not yet sent to the caller. + s.readMu.Lock() v += len(s.readView) + s.readMu.Unlock() if v > math.MaxInt32 { v = math.MaxInt32 } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + // Copy result to userspace. + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err } @@ -2402,52 +2898,49 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil { return 0, err.ToError() } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := ifr.CopyOut(t, args[2].Pointer()) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } - if err := ifconfIoctl(ctx, io, &ifc); err != nil { + if err := ifconfIoctl(ctx, t, io, &ifc); err != nil { return 0, err } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - + _, err := ifc.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2459,10 +2952,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc if v > math.MaxInt32 { v = math.MaxInt32 } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + // Copy result to userspace. + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCOUTQ: @@ -2475,10 +2967,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + // Copy result to userspace. + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: @@ -2504,7 +2995,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2527,21 +3018,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2550,7 +3048,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2561,32 +3059,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2603,6 +3101,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument @@ -2612,7 +3118,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. @@ -2649,9 +3155,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Copy the ifr to userspace. dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { return err } } @@ -2697,7 +3201,7 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { if s.family != linux.AF_INET && s.family != linux.AF_INET6 { // States not implemented for this socket's family. return 0 @@ -2757,6 +3261,8 @@ func (s *SocketOperations) State() uint32 { } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return s.family, s.skType, s.protocol } + +// LINT.ThenChange(./netstack_vfs2.go) |