From 3e9b8ecbfe21ba6c8c788be469fc6cea6a4a40b7 Mon Sep 17 00:00:00 2001 From: Ian Gudger Date: Thu, 13 Jun 2019 18:39:43 -0700 Subject: Plumb context through more layers of filesytem. All functions which allocate objects containing AtomicRefCounts will soon need a context. PiperOrigin-RevId: 253147709 --- pkg/sentry/socket/unix/io.go | 9 ++++++-- pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 24 ++++++++++++++-------- pkg/sentry/socket/unix/transport/connectionless.go | 22 ++++++++++++-------- pkg/sentry/socket/unix/transport/unix.go | 15 +++++++------- pkg/sentry/socket/unix/unix.go | 14 ++++++++----- 6 files changed, 54 insertions(+), 31 deletions(-) (limited to 'pkg/sentry/socket/unix') diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 023c2f135..760c7beab 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -15,6 +15,7 @@ package unix import ( + "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/tcpip" @@ -24,6 +25,8 @@ import ( // // EndpointWriter is not thread-safe. type EndpointWriter struct { + Ctx context.Context + // Endpoint is the transport.Endpoint to write to. Endpoint transport.Endpoint @@ -37,7 +40,7 @@ type EndpointWriter struct { // WriteFromBlocks implements safemem.Writer.WriteFromBlocks. func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) { return safemem.FromVecWriterFunc{func(bufs [][]byte) (int64, error) { - n, err := w.Endpoint.SendMsg(bufs, w.Control, w.To) + n, err := w.Endpoint.SendMsg(w.Ctx, bufs, w.Control, w.To) if err != nil { return int64(n), err.ToError() } @@ -50,6 +53,8 @@ func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) // // EndpointReader is not thread-safe. type EndpointReader struct { + Ctx context.Context + // Endpoint is the transport.Endpoint to read from. Endpoint transport.Endpoint @@ -81,7 +86,7 @@ type EndpointReader struct { // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) { - n, ms, c, ct, err := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From) + n, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, bufs, r.Creds, r.NumRights, r.Peek, r.From) r.Control = c r.ControlTrunc = ct r.MsgSize = ms diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 82173dea7..0b0240336 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -31,6 +31,7 @@ go_library( "//pkg/abi/linux", "//pkg/ilist", "//pkg/refs", + "//pkg/sentry/context", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9d07cde22..e4c416233 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -18,6 +18,7 @@ import ( "sync" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" @@ -111,8 +112,13 @@ type connectionedEndpoint struct { acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"` } +var ( + _ = BoundEndpoint((*connectionedEndpoint)(nil)) + _ = Endpoint((*connectionedEndpoint)(nil)) +) + // NewConnectioned creates a new unbound connectionedEndpoint. -func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint { +func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -122,7 +128,7 @@ func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint { } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. -func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { +func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { a := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -163,7 +169,7 @@ func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. -func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { +func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), @@ -238,7 +244,7 @@ func (e *connectionedEndpoint) Close() { } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. -func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { +func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { if ce.Type() != e.stype { return syserr.ErrConnectionRefused } @@ -334,19 +340,19 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur } // UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. -func (e *connectionedEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) { +func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) { return nil, syserr.ErrConnectionRefused } // Connect attempts to directly connect to another Endpoint. // Implements Endpoint.Connect. -func (e *connectionedEndpoint) Connect(server BoundEndpoint) *syserr.Error { +func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error { returnConnect := func(r Receiver, ce ConnectedEndpoint) { e.receiver = r e.connected = ce } - return server.BidirectionalConnect(e, returnConnect) + return server.BidirectionalConnect(ctx, e, returnConnect) } // Listen starts listening on the connection. @@ -426,13 +432,13 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. -func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { +func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { // Stream sockets do not support specifying the endpoint. Seqpacket // sockets ignore the passed endpoint. if e.stype == linux.SOCK_STREAM && to != nil { return 0, syserr.ErrNotSupported } - return e.baseEndpoint.SendMsg(data, c, to) + return e.baseEndpoint.SendMsg(ctx, data, c, to) } // Readiness returns the current readiness of the connectionedEndpoint. For diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 254148286..cb2b60339 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -16,6 +16,7 @@ package transport import ( "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" @@ -32,8 +33,13 @@ type connectionlessEndpoint struct { baseEndpoint } +var ( + _ = BoundEndpoint((*connectionlessEndpoint)(nil)) + _ = Endpoint((*connectionlessEndpoint)(nil)) +) + // NewConnectionless creates a new unbound dgram endpoint. -func NewConnectionless() Endpoint { +func NewConnectionless(ctx context.Context) Endpoint { ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}} ep.receiver = &queueReceiver{readQueue: &queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}} return ep @@ -72,12 +78,12 @@ func (e *connectionlessEndpoint) Close() { } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. -func (e *connectionlessEndpoint) BidirectionalConnect(ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { +func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { return syserr.ErrConnectionRefused } // UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect. -func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) { +func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) { e.Lock() r := e.receiver e.Unlock() @@ -96,12 +102,12 @@ func (e *connectionlessEndpoint) UnidirectionalConnect() (ConnectedEndpoint, *sy // SendMsg writes data and a control message to the specified endpoint. // This method does not block if the data cannot be written. -func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { +func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { if to == nil { - return e.baseEndpoint.SendMsg(data, c, nil) + return e.baseEndpoint.SendMsg(ctx, data, c, nil) } - connected, err := to.UnidirectionalConnect() + connected, err := to.UnidirectionalConnect(ctx) if err != nil { return 0, syserr.ErrInvalidEndpointState } @@ -124,8 +130,8 @@ func (e *connectionlessEndpoint) Type() linux.SockType { } // Connect attempts to connect directly to server. -func (e *connectionlessEndpoint) Connect(server BoundEndpoint) *syserr.Error { - connected, err := server.UnidirectionalConnect() +func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error { + connected, err := server.UnidirectionalConnect(ctx) if err != nil { return err } diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index a4d41e355..b0765ba55 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -20,6 +20,7 @@ import ( "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -120,13 +121,13 @@ type Endpoint interface { // CMTruncated indicates that the numRights hint was used to receive fewer // than the total available SCM_RIGHTS FDs. Additional truncation may be // required by the caller. - RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error) + RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen uintptr, cm ControlMessages, CMTruncated bool, err *syserr.Error) // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. // // SendMsg does not take ownership of any of its arguments on error. - SendMsg([][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error) + SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (uintptr, *syserr.Error) // Connect connects this endpoint directly to another. // @@ -134,7 +135,7 @@ type Endpoint interface { // endpoint passed in as a parameter. // // The error codes are the same as Connect. - Connect(server BoundEndpoint) *syserr.Error + Connect(ctx context.Context, server BoundEndpoint) *syserr.Error // Shutdown closes the read and/or write end of the endpoint connection // to its peer. @@ -215,7 +216,7 @@ type BoundEndpoint interface { // // This method will return syserr.ErrConnectionRefused on endpoints with a // type that isn't SockStream or SockSeqpacket. - BidirectionalConnect(ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error + BidirectionalConnect(ctx context.Context, ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error // UnidirectionalConnect establishes a write-only connection to a unix // endpoint. @@ -225,7 +226,7 @@ type BoundEndpoint interface { // // This method will return syserr.ErrConnectionRefused on a non-SockDgram // endpoint. - UnidirectionalConnect() (ConnectedEndpoint, *syserr.Error) + UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) // Passcred returns whether or not the SO_PASSCRED socket option is // enabled on this end. @@ -776,7 +777,7 @@ func (e *baseEndpoint) Connected() bool { } // RecvMsg reads data and a control message from the endpoint. -func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) { +func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights uintptr, peek bool, addr *tcpip.FullAddress) (uintptr, uintptr, ControlMessages, bool, *syserr.Error) { e.Lock() if e.receiver == nil { @@ -802,7 +803,7 @@ func (e *baseEndpoint) RecvMsg(data [][]byte, creds bool, numRights uintptr, pee // SendMsg writes data and a control message to the endpoint's peer. // This method does not block if the data cannot be written. -func (e *baseEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { +func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { e.Lock() if !e.Connected() { e.Unlock() diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 58483a279..97db87f3e 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -363,7 +363,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo defer ep.Release() // Connect the server endpoint. - return s.ep.Connect(ep) + return s.ep.Connect(t, ep) } // Writev implements fs.FileOperations.Write. @@ -372,11 +372,12 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO ctrl := control.New(t, s.ep, nil) if src.NumBytes() == 0 { - nInt, err := s.ep.SendMsg([][]byte{}, ctrl, nil) + nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil) return int64(nInt), err.ToError() } return src.CopyInTo(ctx, &EndpointWriter{ + Ctx: ctx, Endpoint: s.ep, Control: ctrl, To: nil, @@ -387,6 +388,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO // a transport.Endpoint. func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ + Ctx: t, Endpoint: s.ep, Control: controlMessages.Unix, To: nil, @@ -486,6 +488,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS return 0, nil } return dst.CopyOutFrom(ctx, &EndpointReader{ + Ctx: ctx, Endpoint: s.ep, NumRights: 0, Peek: false, @@ -522,6 +525,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } r := EndpointReader{ + Ctx: t, Endpoint: s.ep, Creds: wantCreds, NumRights: uintptr(numRights), @@ -635,9 +639,9 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs var ep transport.Endpoint switch stype { case linux.SOCK_DGRAM: - ep = transport.NewConnectionless() + ep = transport.NewConnectionless(t) case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: - ep = transport.NewConnectioned(stype, t.Kernel()) + ep = transport.NewConnectioned(t, stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } @@ -660,7 +664,7 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F } // Create the endpoints and sockets. - ep1, ep2 := transport.NewPair(stype, t.Kernel()) + ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) s1 := New(t, ep1, stype) s2 := New(t, ep2, stype) -- cgit v1.2.3