diff options
Diffstat (limited to 'pkg/sentry/socket/epsocket/epsocket.go')
-rw-r--r-- | pkg/sentry/socket/epsocket/epsocket.go | 227 |
1 files changed, 201 insertions, 26 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index 635042263..5812085fa 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -26,13 +26,16 @@ package epsocket import ( "bytes" + "io" "math" + "reflect" "sync" "syscall" "time" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/context" @@ -52,6 +55,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" "gvisor.dev/gvisor/pkg/waiter" ) @@ -205,6 +209,10 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and + // transport.Endpoint.SetSockOptInt. + SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error @@ -224,7 +232,6 @@ type SocketOperations struct { fsutil.FileNoopFlush `state:"nosave"` fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout *waiter.Queue @@ -409,17 +416,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return int64(n), nil } -// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand -// based on the requested size. +// WriteTo implements fs.FileOperations.WriteTo. +func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) { + s.readMu.Lock() + + // Copy as much data as possible. + done := int64(0) + for count > 0 { + // This may return a blocking error. + if err := s.fetchReadView(); err != nil { + s.readMu.Unlock() + return done, err.ToError() + } + + // Write to the underlying file. + n, err := dst.Write(s.readView) + done += int64(n) + count -= int64(n) + if dup { + // That's all we support for dup. This is generally + // supported by any Linux system calls, but the + // expectation is that now a caller will call read to + // actually remove these bytes from the socket. + break + } + + // Drop that part of the view. + s.readView.TrimFront(n) + if err != nil { + s.readMu.Unlock() + return done, err + } + } + + s.readMu.Unlock() + return done, nil +} + +// ioSequencePayload implements tcpip.Payload. +// +// t copies user memory bytes on demand based on the requested size. type ioSequencePayload struct { ctx context.Context src usermem.IOSequence } -// Get implements tcpip.Payload. -func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { - if size > i.Size() { - size = i.Size() +// FullPayload implements tcpip.Payloader.FullPayload +func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) { + return i.Payload(int(i.src.NumBytes())) +} + +// Payload implements tcpip.Payloader.Payload. +func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) { + if max := int(i.src.NumBytes()); size > max { + size = max } v := buffer.NewView(size) if _, err := i.src.CopyIn(i.ctx, v); err != nil { @@ -428,11 +478,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) { return v, nil } -// Size implements tcpip.Payload. -func (i *ioSequencePayload) Size() int { - return int(i.src.NumBytes()) -} - // DropFirst drops the first n bytes from underlying src. func (i *ioSequencePayload) DropFirst(n int) { i.src = i.src.DropFirst(int(n)) @@ -466,6 +511,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO return int64(n), nil } +// readerPayload implements tcpip.Payloader. +// +// It allocates a view and reads from a reader on-demand, based on available +// capacity in the endpoint. +type readerPayload struct { + ctx context.Context + r io.Reader + count int64 + err error +} + +// FullPayload implements tcpip.Payloader.FullPayload. +func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) { + return r.Payload(int(r.count)) +} + +// Payload implements tcpip.Payloader.Payload. +func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) { + if size > int(r.count) { + size = int(r.count) + } + v := buffer.NewView(size) + n, err := r.r.Read(v) + if n > 0 { + // We ignore the error here. It may re-occur on subsequent + // reads, but for now we can enqueue some amount of data. + r.count -= int64(n) + return v[:n], nil + } + if err == syserror.ErrWouldBlock { + return nil, tcpip.ErrWouldBlock + } else if err != nil { + r.err = err // Save for propation. + return nil, tcpip.ErrBadAddress + } + + // There is no data and no error. Return an error, which will propagate + // r.err, which will be nil. This is the desired result: (0, nil). + return nil, tcpip.ErrBadAddress +} + +// ReadFrom implements fs.FileOperations.ReadFrom. +func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) { + f := &readerPayload{ctx: ctx, r: r, count: count} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{ + // Reads may be destructive but should be very fast, + // so we can't release the lock while copying data. + Atomic: true, + }) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + t := ctx.(*kernel.Task) + if err := t.Block(resCh); err != nil { + return 0, syserr.FromError(err).ToError() + } + + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ + Atomic: true, // See above. + }) + } + if err == tcpip.ErrWouldBlock { + return n, syserror.ErrWouldBlock + } else if err != nil { + return int64(n), f.err // Propagate error. + } + + return int64(n), nil +} + // Readiness returns a mask of ready events for socket s. func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) @@ -774,8 +891,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.SendBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -790,8 +907,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - var size tcpip.ReceiveBufferSizeOption - if err := ep.GetSockOpt(&size); err != nil { + size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -825,6 +942,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return int32(v), nil + case linux.SO_BINDTODEVICE: + var v tcpip.BindToDeviceOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + if len(v) == 0 { + return []byte{}, nil + } + if outLen < linux.IFNAMSIZ { + return nil, syserr.ErrInvalidArgument + } + return append([]byte(v), 0), nil + case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument @@ -1162,7 +1292,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v))) case linux.SO_RCVBUF: if len(optVal) < sizeOfInt32 { @@ -1170,7 +1300,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v))) case linux.SO_REUSEADDR: if len(optVal) < sizeOfInt32 { @@ -1188,6 +1318,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i v := usermem.ByteOrder.Uint32(optVal) return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + case linux.SO_BINDTODEVICE: + n := bytes.IndexByte(optVal, 0) + if n == -1 { + n = len(optVal) + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -2057,7 +2194,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] n, _, err = s.Endpoint.Write(v, opts) } dontWait := flags&linux.MSG_DONTWAIT != 0 - if err == nil && (n >= int64(v.Size()) || dontWait) { + if err == nil && (n >= v.src.NumBytes() || dontWait) { // Complete write. return int(n), nil } @@ -2082,7 +2219,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return 0, syserr.TranslateNetstackError(err) } - if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock { + if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock { return int(total), nil } @@ -2101,7 +2238,8 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. - if int(args[1].Int()) == syscall.SIOCGSTAMP { + switch args[1].Int() { + case syscall.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2113,6 +2251,25 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, AddressSpaceActive: true, }) return 0, err + + case linux.TIOCINQ: + v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() + } + + // Add bytes removed from the endpoint but not yet sent to the caller. + v += len(s.readView) + + if v > math.MaxInt32 { + v = math.MaxInt32 + } + + // Copy result to user-space. + _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ + AddressSpaceActive: true, + }) + return 0, err } return Ioctl(ctx, s.Endpoint, io, args) @@ -2184,9 +2341,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc return 0, err case linux.TIOCOUTQ: - var v tcpip.SendQueueSizeOption - if err := ep.GetSockOpt(&v); err != nil { - return 0, syserr.TranslateNetstackError(err).ToError() + v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption) + if terr != nil { + return 0, syserr.TranslateNetstackError(terr).ToError() } if v > math.MaxInt32 { @@ -2421,7 +2578,8 @@ func (s *SocketOperations) State() uint32 { return 0 } - if !s.isPacketBased() { + switch { + case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP: // TCP socket. switch tcp.EndpointState(s.Endpoint.State()) { case tcp.StateEstablished: @@ -2450,9 +2608,26 @@ func (s *SocketOperations) State() uint32 { // Internal or unknown state. return 0 } + case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP: + // UDP socket. + switch udp.EndpointState(s.Endpoint.State()) { + case udp.StateInitial, udp.StateBound, udp.StateClosed: + return linux.TCP_CLOSE + case udp.StateConnected: + return linux.TCP_ESTABLISHED + default: + return 0 + } + case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6: + // TODO(b/112063468): Export states for ICMP sockets. + case s.skType == linux.SOCK_RAW: + // TODO(b/112063468): Export states for raw sockets. + default: + // Unknown transport protocol, how did we make this socket? + log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem()) + return 0 } - // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } |