diff options
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r-- | pkg/sentry/socket/netlink/BUILD | 1 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/port/BUILD | 4 | ||||
-rw-r--r-- | pkg/sentry/socket/netlink/socket.go | 22 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/BUILD (renamed from pkg/sentry/socket/epsocket/BUILD) | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/device.go (renamed from pkg/sentry/socket/epsocket/device.go) | 6 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/netstack.go (renamed from pkg/sentry/socket/epsocket/epsocket.go) | 360 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/provider.go (renamed from pkg/sentry/socket/epsocket/provider.go) | 4 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/save_restore.go (renamed from pkg/sentry/socket/epsocket/save_restore.go) | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/netstack/stack.go (renamed from pkg/sentry/socket/epsocket/stack.go) | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/rpcinet/BUILD | 9 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/BUILD | 2 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/queue.go | 3 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/transport/unix.go | 82 | ||||
-rw-r--r-- | pkg/sentry/socket/unix/unix.go | 24 |
14 files changed, 418 insertions, 109 deletions
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index 45ebb2a0e..7da68384e 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -21,6 +21,7 @@ go_library( "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", + "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 9e2e12799..445080aa4 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,6 +1,8 @@ +load("@io_bazel_rules_go//go:def.bzl", "go_test") + package(licenses = ["notice"]) -load("//tools/go_stateify:defs.bzl", "go_library", "go_test") +load("//tools/go_stateify:defs.bzl", "go_library") go_library( name = "port", diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index d0aab293d..b2732ca29 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -28,6 +28,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" + "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" @@ -416,6 +417,24 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have Peek: flags&linux.MSG_PEEK != 0, } + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. However, the memory manager for the destination will not read + // the endpoint if the destination is zero length. + // + // In order for the endpoint to be read when the destination size is zero, + // we must cause a read of the endpoint by using a separate fake zero + // length block sequence and calling the EndpointReader directly. + if trunc && dst.Addrs.NumBytes() == 0 { + // Perform a read to a zero byte block sequence. We can ignore the + // original destination since it was zero bytes. The length returned by + // ReadToBlocks is ignored and we return the full message length to comply + // with MSG_TRUNC. + _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0)))) + return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + } + if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { @@ -499,6 +518,9 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error PortID: uint32(ms.PortID), }) + // Add the dump_done_errno payload. + m.Put(int64(0)) + _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/netstack/BUILD index e927821e1..60523f79a 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -3,15 +3,15 @@ package(licenses = ["notice"]) load("//tools/go_stateify:defs.bzl", "go_library") go_library( - name = "epsocket", + name = "netstack", srcs = [ "device.go", - "epsocket.go", + "netstack.go", "provider.go", "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/epsocket", + importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/netstack/device.go index 85484d5b1..fbeb89fb8 100644 --- a/pkg/sentry/socket/epsocket/device.go +++ b/pkg/sentry/socket/netstack/device.go @@ -12,9 +12,9 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import "gvisor.dev/gvisor/pkg/sentry/device" -// epsocketDevice is the endpoint socket virtual device. -var epsocketDevice = device.NewAnonDevice() +// netstackDevice is the endpoint socket virtual device. +var netstackDevice = device.NewAnonDevice() diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/netstack/netstack.go index 635042263..0ae573b45 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package epsocket provides an implementation of the socket.Socket interface +// Package netstack provides an implementation of the socket.Socket interface // that is backed by a tcpip.Endpoint. // // It does not depend on any particular endpoint implementation, and thus can @@ -22,17 +22,20 @@ // Lock ordering: netstack => mm: ioSequencePayload copies user memory inside // tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during // this operation. -package epsocket +package netstack import ( "bytes" + "io" "math" + "reflect" "sync" "syscall" "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -52,6 +55,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -133,11 +137,13 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), @@ -151,6 +157,7 @@ var Metrics = tcpip.Stats{ ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."), InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."), SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."), + SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."), ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."), ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."), Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."), @@ -166,13 +173,18 @@ var Metrics = tcpip.Stats{ UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."), ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."), MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), - PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent via sendUDP."), + PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), + PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), }, } +// DefaultTTL is linux's default TTL. All network protocols in all stacks used +// with this package must have this value set as their default TTL. +const DefaultTTL = 64 + const sizeOfInt32 int = 4 -var errStackType = syserr.New("expected but did not receive an epsocket.Stack", linux.EINVAL) +var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL) // ntohs converts a 16-bit number from network byte order to host byte order. It // assumes that the host is little endian. @@ -205,6 +217,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -224,7 +240,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -255,8 +270,8 @@ type SocketOperations struct { // valid when timestampValid is true. It is protected by readMu. timestampNS int64 - // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket - // level, because it takes into account data from readView. + // sockOptInq corresponds to TCP_INQ. It is implemented at this level + // because it takes into account data from readView. sockOptInq bool } @@ -268,7 +283,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue } } - dirent := socket.NewDirent(t, epsocketDevice) + dirent := socket.NewDirent(t, netstackDevice) defer dirent.DecRef() return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ Queue: queue, @@ -409,17 +424,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + s.readMu.Unlock() + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + break + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + s.readMu.Unlock() + return done, err + } + } + + s.readMu.Unlock() + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -428,11 +486,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -466,6 +519,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + Atomic: true, // See above. + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -643,7 +768,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // tcpip.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -716,7 +841,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return getSockOptIPv6(t, ep, name, outLen) case linux.SOL_IP: - return getSockOptIP(t, ep, name, outLen) + return getSockOptIP(t, ep, name, outLen, family) case linux.SOL_UDP, linux.SOL_ICMPV6, @@ -774,8 +899,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -790,8 +915,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -825,6 +950,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1039,6 +1177,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + // Length handling for parity with Linux. + if outLen == 0 { + return make([]byte, 0), nil + } + var v tcpip.IPv6TrafficClassOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + uintv := uint32(v) + // Linux truncates the output binary to outLen. + ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + // Handle cases where outLen is lesser than sizeOfInt32. + if len(ib) > outLen { + ib = ib[:outLen] + } + return ib, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1046,8 +1203,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { switch name { + case linux.IP_TTL: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TTLOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // Fill in the default value, if needed. + if v == 0 { + v = DefaultTTL + } + + return int32(v), nil + case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1089,6 +1263,20 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac } return int32(0), nil + case linux.IP_TOS: + // Length handling for parity with Linux. + if outLen == 0 { + return []byte(nil), nil + } + var v tcpip.IPv4TOSOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if outLen < sizeOfInt32 { + return uint8(v), nil + } + return int32(v), nil + default: emitUnimplementedEventIP(t, name) } @@ -1099,7 +1287,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac // tcpip.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is - // implemented specifically for epsocket.SocketOperations rather than + // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket // options where the implementation is not shared, as unix sockets need // their own support for SO_TIMESTAMP. @@ -1162,7 +1350,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1170,7 +1358,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -1188,6 +1376,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -1380,6 +1575,19 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) t.Kernel().EmitUnimplementedEvent(t) + case linux.IPV6_TCLASS: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < -1 || v > 255 { + return syserr.ErrInvalidArgument + } + if v == -1 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + default: emitUnimplementedEventIPv6(t, name) } @@ -1511,6 +1719,30 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s t.Kernel().EmitUnimplementedEvent(t) return syserr.ErrInvalidArgument + case linux.IP_TTL: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + // -1 means default TTL. + if v == -1 { + v = 0 + } else if v < 1 || v > 255 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + + case linux.IP_TOS: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, linux.IP_BLOCK_SOURCE, @@ -1534,9 +1766,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, - linux.IP_TOS, linux.IP_TRANSPARENT, - linux.IP_TTL, linux.IP_UNBLOCK_SOURCE, linux.IP_UNICAST_IF, linux.IP_XFRM_POLICY, @@ -2057,7 +2287,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2082,7 +2312,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2098,10 +2328,11 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. - if int(args[1].Int()) == syscall.SIOCGSTAMP { + switch args[1].Int() { + case syscall.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2113,6 +2344,25 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, AddressSpaceActive: true, }) return 0, err + + case linux.TIOCINQ: + v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() + } + + // Add bytes removed from the endpoint but not yet sent to the caller. + v += len(s.readView) + + if v > math.MaxInt32 { + v = math.MaxInt32 + } + + // Copy result to user-space. + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err } return Ioctl(ctx, s.Endpoint, io, args) @@ -2184,9 +2434,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { @@ -2381,7 +2631,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Flag values and meanings are described in greater detail in netdevice(7) in // the SIOCGIFFLAGS section. func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) { - // epsocket should only ever be passed an epsocket.Stack. + // We should only ever be passed a netstack.Stack. epstack, ok := stack.(*Stack) if !ok { return 0, errStackType @@ -2421,7 +2671,8 @@ func (s *SocketOperations) State() uint32 { return 0 } - if !s.isPacketBased() { + switch { + case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -2450,9 +2701,26 @@ func (s *SocketOperations) State() uint32 { // Internal or unknown state. return 0 } + case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + // UDP socket. + switch udp.EndpointState(s.Endpoint.State()) { + case udp.StateInitial, udp.StateBound, udp.StateClosed: + return linux.TCP_CLOSE + case udp.StateConnected: + return linux.TCP_ESTABLISHED + default: + return 0 + } + case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + // TODO(b/112063468): Export states for ICMP sockets. + case s.skType == linux.SOCK_RAW: + // TODO(b/112063468): Export states for raw sockets. + default: + // Unknown transport protocol, how did we make this socket? + log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem()) + return 0 } - // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/netstack/provider.go index 421f93dc4..357a664cc 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "syscall" @@ -65,7 +65,7 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { - return 0, true, syserr.ErrPermissionDenied + return 0, true, syserr.ErrNotPermitted } switch protocol { diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go index f7b8c10cc..c7aaf722a 100644 --- a/pkg/sentry/socket/epsocket/save_restore.go +++ b/pkg/sentry/socket/netstack/save_restore.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/tcpip/stack" diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/netstack/stack.go index 7cf7ff735..fda0156e5 100644 --- a/pkg/sentry/socket/epsocket/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package epsocket +package netstack import ( "gvisor.dev/gvisor/pkg/abi/linux" diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 5061dcbde..3a6baa308 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -1,5 +1,6 @@ load("//tools/go_stateify:defs.bzl", "go_library") load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") +load("@rules_cc//cc:defs.bzl", "cc_proto_library") package(licenses = ["notice"]) @@ -49,6 +50,14 @@ proto_library( ], ) +cc_proto_library( + name = "syscall_rpc_cc_proto", + visibility = [ + "//visibility:public", + ], + deps = [":syscall_rpc_proto"], +) + go_proto_library( name = "syscall_rpc_go_proto", importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto", diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index da9977fde..830f4da10 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -24,7 +24,7 @@ go_library( "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", - "//pkg/sentry/socket/epsocket", + "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index 1c71609e2..e27b1c714 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -161,7 +161,8 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) { if q.dataList.Front() == nil { err := syserr.ErrWouldBlock if q.closed { - if err = syserr.ErrClosedForReceive; q.unread { + err = syserr.ErrClosedForReceive + if q.unread { err = syserr.ErrConnectionReset } } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index a103b1aab..529a7a7a9 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -175,6 +175,10 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has + // the int type. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error @@ -847,6 +851,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { return nil } +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { + return nil +} + func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: @@ -862,65 +870,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return -1, tcpip.ErrQueueSizeNotSupported } return v, nil - default: - return -1, tcpip.ErrUnknownProtocolOption - } -} - -// GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { - case tcpip.ErrorOption: - return nil - case *tcpip.SendQueueSizeOption: + case tcpip.SendQueueSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize()) + v := e.connected.SendQueuedSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported - } - *o = qs - return nil - - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - return nil + return int(v), nil - case *tcpip.SendBufferSizeOption: + case tcpip.SendBufferSizeOption: e.Lock() if !e.Connected() { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize()) + v := e.connected.SendMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported } - *o = qs - return nil + return int(v), nil - case *tcpip.ReceiveBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: e.Lock() if e.receiver == nil { e.Unlock() - return tcpip.ErrNotConnected + return -1, tcpip.ErrNotConnected } - qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize()) + v := e.receiver.RecvMaxQueueSize() e.Unlock() - if qs < 0 { - return tcpip.ErrQueueSizeNotSupported + if v < 0 { + return -1, tcpip.ErrQueueSizeNotSupported + } + return int(v), nil + + default: + return -1, tcpip.ErrUnknownProtocolOption + } +} + +// GetSockOpt implements tcpip.Endpoint.GetSockOpt. +func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { + switch o := opt.(type) { + case tcpip.ErrorOption: + return nil + + case *tcpip.PasscredOption: + if e.Passcred() { + *o = tcpip.PasscredOption(1) + } else { + *o = tcpip.PasscredOption(0) } - *o = qs return nil case *tcpip.KeepaliveEnabledOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 8e4f06a22..1aaae8487 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -31,7 +31,7 @@ import ( ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/control" - "gvisor.dev/gvisor/pkg/sentry/socket/epsocket" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" @@ -40,8 +40,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -// SocketOperations is a Unix socket. It is similar to an epsocket, except it -// is backed by a transport.Endpoint instead of a tcpip.Endpoint. +// SocketOperations is a Unix socket. It is similar to a netstack socket, +// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint. // // +stateify savable type SocketOperations struct { @@ -116,7 +116,7 @@ func (s *SocketOperations) Endpoint() transport.Endpoint { // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) if err != nil { return "", err } @@ -143,7 +143,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } @@ -155,19 +155,19 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, return nil, 0, syserr.TranslateNetstackError(err) } - a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr) + a, l := netstack.ConvertAddress(linux.AF_UNIX, addr) return a, l, nil } // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - return epsocket.Ioctl(ctx, s.ep, io, args) + return netstack.Ioctl(ctx, s.ep, io, args) } // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { - return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by @@ -474,13 +474,13 @@ func (s *SocketOperations) EventUnregister(e *waiter.Entry) { // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { - return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal) + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) } // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { - f, err := epsocket.ConvertShutdown(how) + f, err := netstack.ConvertShutdown(how) if err != nil { return err } @@ -546,7 +546,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { @@ -581,7 +581,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags var from linux.SockAddr var fromLen uint32 if r.From != nil { - from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From) + from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From) } if r.ControlTrunc { |