From d3ed9baac0dc967eaf6d3e3f986cafe60604121a Mon Sep 17 00:00:00 2001 From: Michael Pratt Date: Wed, 5 Jun 2019 13:59:01 -0700 Subject: Implement dumpability tracking and checks We don't actually support core dumps, but some applications want to get/set dumpability, which still has an effect in procfs. Lack of support for set-uid binaries or fs creds simplifies things a bit. As-is, processes started via CreateProcess (i.e., init and sentryctl exec) have normal dumpability. I'm a bit torn on whether sentryctl exec tasks should be dumpable, but at least since they have no parent normal UID/GID checks should protect them. PiperOrigin-RevId: 251712714 --- pkg/abi/linux/prctl.go | 7 +++++++ 1 file changed, 7 insertions(+) (limited to 'pkg/abi/linux') diff --git a/pkg/abi/linux/prctl.go b/pkg/abi/linux/prctl.go index 0428282dd..391cfaa1c 100644 --- a/pkg/abi/linux/prctl.go +++ b/pkg/abi/linux/prctl.go @@ -155,3 +155,10 @@ const ( ARCH_GET_GS = 0x1004 ARCH_SET_CPUID = 0x1012 ) + +// Flags for prctl(PR_SET_DUMPABLE), defined in include/linux/sched/coredump.h. +const ( + SUID_DUMP_DISABLE = 0 + SUID_DUMP_USER = 1 + SUID_DUMP_ROOT = 2 +) -- cgit v1.2.3 From 2d2831e3541c8ae3c84f17cfd1bf0a26f2027044 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 15:03:44 -0700 Subject: Track and export socket state. This is necessary for implementing network diagnostic interfaces like /proc/net/{tcp,udp,unix} and sock_diag(7). For pass-through endpoints such as hostinet, we obtain the socket state from the backend. For netstack, we add explicit tracking of TCP states. PiperOrigin-RevId: 251934850 --- pkg/abi/linux/socket.go | 16 ++ pkg/sentry/fs/proc/net.go | 20 +-- pkg/sentry/socket/epsocket/epsocket.go | 44 +++++ pkg/sentry/socket/hostinet/socket.go | 24 +++ pkg/sentry/socket/netlink/socket.go | 5 + pkg/sentry/socket/rpcinet/socket.go | 6 + pkg/sentry/socket/socket.go | 4 + pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 9 ++ pkg/sentry/socket/unix/transport/connectionless.go | 16 ++ pkg/sentry/socket/unix/transport/unix.go | 4 + pkg/sentry/socket/unix/unix.go | 5 + pkg/tcpip/stack/transport_test.go | 4 + pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 6 + pkg/tcpip/transport/raw/endpoint.go | 5 + pkg/tcpip/transport/tcp/accept.go | 12 +- pkg/tcpip/transport/tcp/connect.go | 26 ++- pkg/tcpip/transport/tcp/endpoint.go | 174 ++++++++++++++------ pkg/tcpip/transport/tcp/endpoint_state.go | 42 ++--- pkg/tcpip/transport/tcp/rcv.go | 37 +++++ pkg/tcpip/transport/tcp/snd.go | 4 + pkg/tcpip/transport/tcp/tcp_test.go | 131 ++++++++++++--- pkg/tcpip/transport/tcp/testing/context/context.go | 39 ++++- pkg/tcpip/transport/udp/endpoint.go | 6 + test/syscalls/linux/proc_net_unix.cc | 178 +++++++++++++++++++++ 27 files changed, 696 insertions(+), 127 deletions(-) (limited to 'pkg/abi/linux') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 417840731..44bd69df6 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -200,6 +200,22 @@ const ( SS_DISCONNECTING = 4 // In process of disconnecting. ) +// TCP protocol states, from include/net/tcp_states.h. +const ( + TCP_ESTABLISHED uint32 = iota + 1 + TCP_SYN_SENT + TCP_SYN_RECV + TCP_FIN_WAIT1 + TCP_FIN_WAIT2 + TCP_TIME_WAIT + TCP_CLOSE + TCP_CLOSE_WAIT + TCP_LAST_ACK + TCP_LISTEN + TCP_CLOSING + TCP_NEW_SYN_RECV +) + // SockAddrMax is the maximum size of a struct sockaddr, from // uapi/linux/socket.h. const SockAddrMax = 128 diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 4a107c739..3daaa962c 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -240,24 +240,6 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } } - var sockState int - switch sops.Endpoint().Type() { - case linux.SOCK_DGRAM: - sockState = linux.SS_CONNECTING - // Unlike Linux, we don't have unbound connection-less sockets, - // so no SS_DISCONNECTING. - - case linux.SOCK_SEQPACKET: - fallthrough - case linux.SOCK_STREAM: - // Connectioned. - if sops.Endpoint().(transport.ConnectingEndpoint).Connected() { - sockState = linux.SS_CONNECTED - } else { - sockState = linux.SS_UNCONNECTED - } - } - // In the socket entry below, the value for the 'Num' field requires // some consideration. Linux prints the address to the struct // unix_sock representing a socket in the kernel, but may redact the @@ -282,7 +264,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s 0, // Protocol, always 0 for UDS. sockFlags, // Flags. sops.Endpoint().Type(), // Type. - sockState, // State. + sops.State(), // State. sfile.InodeID(), // Inode. ) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index de4b963da..f91c5127a 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -52,6 +52,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -2281,3 +2282,46 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { } return rv } + +// State implements socket.Socket.State. State translates the internal state +// returned by netstack to values defined by Linux. +func (s *SocketOperations) State() uint32 { + if s.family != linux.AF_INET && s.family != linux.AF_INET6 { + // States not implemented for this socket's family. + return 0 + } + + if !s.isPacketBased() { + // TCP socket. + switch tcp.EndpointState(s.Endpoint.State()) { + case tcp.StateEstablished: + return linux.TCP_ESTABLISHED + case tcp.StateSynSent: + return linux.TCP_SYN_SENT + case tcp.StateSynRecv: + return linux.TCP_SYN_RECV + case tcp.StateFinWait1: + return linux.TCP_FIN_WAIT1 + case tcp.StateFinWait2: + return linux.TCP_FIN_WAIT2 + case tcp.StateTimeWait: + return linux.TCP_TIME_WAIT + case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError: + return linux.TCP_CLOSE + case tcp.StateCloseWait: + return linux.TCP_CLOSE_WAIT + case tcp.StateLastAck: + return linux.TCP_LAST_ACK + case tcp.StateListen: + return linux.TCP_LISTEN + case tcp.StateClosing: + return linux.TCP_CLOSING + default: + // Internal or unknown state. + return 0 + } + } + + // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. + return 0 +} diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 41f9693bb..0d75580a3 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,9 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/binary" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" + "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" @@ -519,6 +521,28 @@ func translateIOSyscallError(err error) error { return err } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + info := linux.TCPInfo{} + buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo) + if err != nil { + if err != syscall.ENOPROTOOPT { + log.Warningf("Failed to get TCP socket info from %+v: %v", s, err) + } + // For non-TCP sockets, silently ignore the failure. + return 0 + } + if len(buf) != linux.SizeOfTCPInfo { + // Unmarshal below will panic if getsockopt returns a buffer of + // unexpected size. + log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo) + return 0 + } + + binary.Unmarshal(buf, usermem.ByteOrder, &info) + return uint32(info.State) +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index afd06ca33..16c79aa33 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -616,3 +616,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } + +// State implements socket.Socket.State. +func (s *Socket) State() uint32 { + return s.ep.State() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 55e0b6665..bf42bdf69 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -830,6 +830,12 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + // TODO(b/127845868): Define a new rpc to query the socket state. + return 0 +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 9393acd28..a99423365 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -116,6 +116,10 @@ type Socket interface { // SendTimeout gets the current timeout (in ns) for send operations. Zero // means no timeout, and negative means DONTWAIT. SendTimeout() int64 + + // State returns the current state of the socket, as represented by Linux in + // procfs. The returned state value is protocol-specific. + State() uint32 } // Provider is the interface implemented by providers of sockets for specific diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 5a2de0c4c..52f324eed 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -28,6 +28,7 @@ go_library( importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", "//pkg/ilist", "//pkg/refs", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 18e492862..9c8ec0365 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -17,6 +17,7 @@ package transport import ( "sync" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -458,3 +459,11 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return ready } + +// State implements socket.Socket.State. +func (e *connectionedEndpoint) State() uint32 { + if e.Connected() { + return linux.SS_CONNECTED + } + return linux.SS_UNCONNECTED +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 43ff875e4..c034cf984 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -15,6 +15,7 @@ package transport import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -194,3 +195,18 @@ func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMa return ready } + +// State implements socket.Socket.State. +func (e *connectionlessEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + + switch { + case e.isBound(): + return linux.SS_UNCONNECTED + case e.Connected(): + return linux.SS_CONNECTING + default: + return linux.SS_DISCONNECTING + } +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37d82bb6b..5fc09af55 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -191,6 +191,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + + // State returns the current state of the socket, as represented by Linux in + // procfs. + State() uint32 } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 388cc0d8b..375542350 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -596,6 +596,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } +// State implements socket.Socket.State. +func (s *SocketOperations) State() uint32 { + return s.ep.State() +} + // provider is a unix domain socket provider. type provider struct{} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8d74f1543..e8a9392b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -188,6 +188,10 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } +func (f *fakeTransportEndpoint) State() uint32 { + return 0 +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f9886c6e4..85ef014d0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -377,6 +377,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error + + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 } // WriteOptions contains options for Endpoint.Write. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9aa6f3978..84a2b53b7 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", + "//pkg/tcpip/transport/tcp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e2b90ef10..b8005093a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -708,3 +708,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't +// expose internal socket state. +func (e *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 1daf5823f..e4ff50c91 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -519,3 +519,8 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b ep.waiterQueue.Notify(waiter.EventIn) } } + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 31e365ae5..a32e20b06 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -226,7 +226,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n.isRegistered = true - n.state = stateConnecting // Create sender and receiver. // @@ -258,8 +257,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.Close() return nil, err } - - ep.state = stateConnected + ep.mu.Lock() + ep.state = StateEstablished + ep.mu.Unlock() // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -276,7 +276,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.RLock() state := e.state e.mu.RUnlock() - if state == stateListen { + if state == StateListen { e.acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { @@ -406,7 +406,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. - n.state = stateConnected + n.state = StateEstablished // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -429,7 +429,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = stateClosed + e.state = StateClose // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 371d2ed29..0ad7bfb38 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -151,6 +151,9 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.listenEP = listenEP + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() } // checkAck checks if the ACK number, if present, of a segment received during @@ -219,6 +222,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: rcvSynOpts.TS, @@ -668,7 +674,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -719,8 +725,7 @@ func (e *endpoint) handleClose() *tcpip.Error { // protocol goroutine. func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) - - e.state = stateError + e.state = StateError e.hardError = err } @@ -876,14 +881,19 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // handshake, and then inform potential waiters about its // completion. h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + e.mu.Lock() + h.ep.state = StateSynSent + e.mu.Unlock() + if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() e.mu.Lock() - e.state = stateError + e.state = StateError e.hardError = err + // Lock released below. epilogue() @@ -905,7 +915,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = stateConnected + e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1005,7 +1015,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != stateError { + if e.state != StateError { close(e.drainDone) <-e.undrain } @@ -1061,8 +1071,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() - if e.state != stateError { - e.state = stateClosed + if e.state != StateError { + e.state = StateClose } // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index fd697402e..23422ca5e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -32,18 +32,81 @@ import ( "gvisor.googlesource.com/gvisor/pkg/waiter" ) -type endpointState int +// EndpointState represents the state of a TCP endpoint. +type EndpointState uint32 +// Endpoint states. Note that are represented in a netstack-specific manner and +// may not be meaningful externally. Specifically, they need to be translated to +// Linux's representation for these states if presented to userspace. const ( - stateInitial endpointState = iota - stateBound - stateListen - stateConnecting - stateConnected - stateClosed - stateError + // Endpoint states internal to netstack. These map to the TCP state CLOSED. + StateInitial EndpointState = iota + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError + + // TCP protocol states. + StateEstablished + StateSynSent + StateSynRecv + StateFinWait1 + StateFinWait2 + StateTimeWait + StateClose + StateCloseWait + StateLastAck + StateListen + StateClosing ) +// connected is the set of states where an endpoint is connected to a peer. +func (s EndpointState) connected() bool { + switch s { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + return true + default: + return false + } +} + +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnecting: + return "CONNECTING" + case StateError: + return "ERROR" + case StateEstablished: + return "ESTABLISHED" + case StateSynSent: + return "SYN-SENT" + case StateSynRecv: + return "SYN-RCVD" + case StateFinWait1: + return "FIN-WAIT1" + case StateFinWait2: + return "FIN-WAIT2" + case StateTimeWait: + return "TIME-WAIT" + case StateClose: + return "CLOSED" + case StateCloseWait: + return "CLOSE-WAIT" + case StateLastAck: + return "LAST-ACK" + case StateListen: + return "LISTEN" + case StateClosing: + return "CLOSING" + default: + panic("unreachable") + } +} + // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota @@ -108,10 +171,14 @@ type endpoint struct { rcvBufUsed int // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID - state endpointState `state:".(endpointState)"` - isPortReserved bool `state:"manual"` + mu sync.RWMutex `state:"nosave"` + id stack.TransportEndpointID + + // state endpointState `state:".(endpointState)"` + // pState ProtocolState + state EndpointState `state:".(EndpointState)"` + + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` @@ -304,6 +371,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite stack: stack, netProto: netProto, waiterQueue: waiterQueue, + state: StateInitial, rcvBufSize: DefaultBufferSize, sndBufSize: DefaultBufferSize, sndMTU: int(math.MaxInt32), @@ -351,14 +419,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { defer e.mu.RUnlock() switch e.state { - case stateInitial, stateBound, stateConnecting: + case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case stateClosed, stateError: + case StateClose, StateError: // Ready for anything. result = mask - case stateListen: + case StateListen: // Check if there's anything in the accepted channel. if (mask & waiter.EventIn) != 0 { if len(e.acceptedChan) > 0 { @@ -366,7 +434,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -427,7 +495,7 @@ func (e *endpoint) Close() { // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == stateListen && e.isPortReserved { + if e.state == StateListen && e.isPortReserved { if e.isRegistered { e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.isRegistered = false @@ -487,15 +555,15 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RLock() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received - // would cause the state to become stateError so we should allow the + // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 { + if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.hardError e.mu.RUnlock() - if s == stateError { + if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -511,7 +579,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -547,9 +615,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c defer e.mu.RUnlock() // The endpoint cannot be written to if it's not connected. - if e.state != stateConnected { + if !e.state.connected() { switch e.state { - case stateError: + case StateError: return 0, nil, e.hardError default: return 0, nil, tcpip.ErrClosedForSend @@ -612,8 +680,8 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; s != stateConnected && s != stateClosed { - if s == stateError { + if s := e.state; !s.connected() && s != StateClose { + if s == StateError { return 0, tcpip.ControlMessages{}, e.hardError } return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -623,7 +691,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -789,7 +857,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -841,7 +909,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == stateListen { + if e.state == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1057,7 +1125,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid := addr.NIC switch e.state { - case stateBound: + case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. if e.boundNICID == 0 { @@ -1070,16 +1138,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid = e.boundNICID - case stateInitial: - // Nothing to do. We'll eventually fill-in the gaps in the ID - // (if any) when we find a route. + case StateInitial: + // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) + // when we find a route. - case stateConnecting: - // A connection request has already been issued but hasn't - // completed yet. + case StateConnecting, StateSynSent, StateSynRecv: + // A connection request has already been issued but hasn't completed + // yet. return tcpip.ErrAlreadyConnecting - case stateConnected: + case StateEstablished: // The endpoint is already connected. If caller hasn't been notified yet, return success. if !e.isConnectNotified { e.isConnectNotified = true @@ -1088,7 +1156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // Otherwise return that it's already connected. return tcpip.ErrAlreadyConnected - case stateError: + case StateError: return e.hardError default: @@ -1154,7 +1222,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.isRegistered = true - e.state = stateConnecting + e.state = StateConnecting e.route = r.Clone() e.boundNICID = nicid e.effectiveNetProtos = netProtos @@ -1175,7 +1243,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = stateConnected + e.state = StateEstablished } if run { @@ -1199,8 +1267,8 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { defer e.mu.Unlock() e.shutdownFlags |= flags - switch e.state { - case stateConnected: + switch { + case e.state.connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1241,7 +1309,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.sndCloseWaker.Assert() } - case stateListen: + case e.state == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1269,7 +1337,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == stateListen && !e.workerCleanup { + if e.state == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1288,7 +1356,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Endpoint must be bound before it can transition to listen mode. - if e.state != stateBound { + if e.state != StateBound { return tcpip.ErrInvalidEndpointState } @@ -1298,7 +1366,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } e.isRegistered = true - e.state = stateListen + e.state = StateListen if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } @@ -1325,7 +1393,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != stateListen { + if e.state != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -1353,7 +1421,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrAlreadyBound } @@ -1408,7 +1476,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = stateBound + e.state = StateBound return nil } @@ -1430,7 +1498,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { + if !e.state.connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1739,3 +1807,11 @@ func (e *endpoint) initGSO() { gso.MaxSize = e.route.GSOMaxSize() e.gso = gso } + +// State implements tcpip.Endpoint.State. It exports the endpoint's protocol +// state for diagnostics. +func (e *endpoint) State() uint32 { + e.mu.Lock() + defer e.mu.Unlock() + return uint32(e.state) +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index e8aed2875..5f30c2374 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,8 +49,8 @@ func (e *endpoint) beforeSave() { defer e.mu.Unlock() switch e.state { - case stateInitial, stateBound: - case stateConnected: + case StateInitial, StateBound: + case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) @@ -66,17 +66,17 @@ func (e *endpoint) beforeSave() { break } fallthrough - case stateListen, stateConnecting: + case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != stateClosed && e.state != stateError { + if e.state != StateClose && e.state != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } fallthrough - case stateError, stateClosed: - for e.state == stateError && e.workerRunning { + case StateError, StateClose: + for e.state == StateError && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -92,7 +92,7 @@ func (e *endpoint) beforeSave() { panic("endpoint still has waiters upon save") } - if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) { + if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -132,7 +132,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { } // saveState is invoked by stateify. -func (e *endpoint) saveState() endpointState { +func (e *endpoint) saveState() EndpointState { return e.state } @@ -146,15 +146,15 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state endpointState) { +func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: connectedLoading.Add(1) - case stateListen: + case StateListen: listenLoading.Add(1) - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } e.state = state @@ -168,7 +168,7 @@ func (e *endpoint) afterLoad() { state := e.state switch state { - case stateInitial, stateBound, stateListen, stateConnecting, stateConnected: + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { @@ -181,7 +181,7 @@ func (e *endpoint) afterLoad() { } bind := func() { - e.state = stateInitial + e.state = StateInitial if len(e.bindAddress) == 0 { e.bindAddress = e.id.LocalAddress } @@ -191,7 +191,7 @@ func (e *endpoint) afterLoad() { } switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { // This endpoint is accepted by netstack but not yet by @@ -211,7 +211,7 @@ func (e *endpoint) afterLoad() { panic("endpoint connecting failed: " + err.String()) } connectedLoading.Done() - case stateListen: + case StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -223,7 +223,7 @@ func (e *endpoint) afterLoad() { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -235,7 +235,7 @@ func (e *endpoint) afterLoad() { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case stateBound: + case StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -244,7 +244,7 @@ func (e *endpoint) afterLoad() { bind() tcpip.AsyncLoading.Done() }() - case stateClosed: + case StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -252,12 +252,12 @@ func (e *endpoint) afterLoad() { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = stateClosed + e.state = StateClose tcpip.AsyncLoading.Done() }() } fallthrough - case stateError: + case StateError: tcpip.DeleteDanglingEndpoint(e) } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index b08a0e356..f02fa6105 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -134,6 +134,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // sequence numbers that have been consumed. TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { r.rcvNxt++ @@ -144,6 +145,25 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.closed = true r.ep.readyToRead(nil) + // We just received a FIN, our next state depends on whether we sent a + // FIN already or not. + r.ep.mu.Lock() + switch r.ep.state { + case StateEstablished: + r.ep.state = StateCloseWait + case StateFinWait1: + if s.flagIsSet(header.TCPFlagAck) { + // FIN-ACK, transition to TIME-WAIT. + r.ep.state = StateTimeWait + } else { + // Simultaneous close, expecting a final ACK. + r.ep.state = StateClosing + } + case StateFinWait2: + r.ep.state = StateTimeWait + } + r.ep.mu.Unlock() + // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the // caller is using it. @@ -156,6 +176,23 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.pendingRcvdSegments[i].decRef() } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] + + return true + } + + // Handle ACK (not FIN-ACK, which we handled above) during one of the + // shutdown states. + if s.flagIsSet(header.TCPFlagAck) { + r.ep.mu.Lock() + switch r.ep.state { + case StateFinWait1: + r.ep.state = StateFinWait2 + case StateClosing: + r.ep.state = StateTimeWait + case StateLastAck: + r.ep.state = StateClose + } + r.ep.mu.Unlock() } return true diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 3464e4be7..b236d7af2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -632,6 +632,10 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) + // Transition to FIN-WAIT1 state since we're initiating an active close. + s.ep.mu.Lock() + s.ep.state = StateFinWait1 + s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b8f0ccaf1..56b490aaa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -168,8 +168,8 @@ func TestTCPResetsSentIncrement(t *testing.T) { // Receive the SYN-ACK reply. b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -269,8 +269,8 @@ func TestConnectResetAfterClose(t *testing.T) { time.Sleep(3 * time.Second) for { b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin { + tcpHdr := header.TCP(header.IPv4(b).Payload()) + if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { // This is a retransmit of the FIN, ignore it. continue } @@ -553,9 +553,13 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } - // This final should be ignored because an ACK on a reset doesn't - // mean anything. + // This final ACK should be ignored because an ACK on a reset doesn't mean + // anything. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -618,6 +622,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.SeqNum(uint32(c.IRS)+1), )) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Cause a RST to be generated by closing the read end now since we have // unread data. c.EP.Shutdown(tcpip.ShutdownRead) @@ -630,6 +638,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } // The ACK to the FIN should now be rejected since the connection has been // closed by a RST. @@ -1510,8 +1522,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { for bytesReceived != dataLen { b := c.GetPacket() numPackets++ - tcp := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcp.Payload()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcpHdr.Payload()) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), @@ -1522,7 +1534,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { ) pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcp.Payload(); !bytes.Equal(pdata, p) { + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { t.Fatalf("got data = %v, want = %v", p, pdata) } bytesReceived += payloadLen @@ -1530,7 +1542,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { if c.TimeStampEnabled { // If timestamp option is enabled, echo back the timestamp and increment // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcp.ParsedOptions() + parsedOpts := tcpHdr.ParsedOptions() tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) options = tsOpt[:] @@ -1757,8 +1769,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) // Wait for retransmit. time.Sleep(1 * time.Second) @@ -1766,8 +1778,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcp.SourcePort()), - checker.SeqNum(tcp.SequenceNumber()), + checker.SrcPort(tcpHdr.SourcePort()), + checker.SeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -1775,8 +1787,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Send SYN-ACK. iss := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -2523,8 +2535,8 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.TCPFlags(header.TCPFlagAck), ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcp.AckNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + ack := seqnum.Value(tcpHdr.AckNumber()) if ack == last { break } @@ -2568,6 +2580,10 @@ func TestReadAfterClosedState(t *testing.T) { ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Send some data and acknowledge the FIN. data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -2589,9 +2605,15 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. + // Give the stack the chance to transition to closed state. Note that since + // both the sender and receiver are now closed, we effectively skip the + // TIME-WAIT state. time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Wait for receive to be notified. select { case <-ch: @@ -3680,9 +3702,15 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } stats := c.Stack().Stats() want := stats.TCP.PassiveConnectionOpenings.Value() + 1 @@ -3826,3 +3854,68 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } } + +func TestEndpointBindListenAcceptState(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + aep, _, err := ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + aep, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Listening endpoint remains in listen state. + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + ep.Close() + // Give worker goroutines time to receive the close notification. + time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Accepted endpoint remains open when the listen endpoint is closed. + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6e12413c6..69a43b6f4 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -532,6 +532,9 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if epRcvBuf != nil { if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { @@ -557,13 +560,16 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. checker.TCPFlags(header.TCPFlagSyn), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -591,8 +597,11 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - c.Port = tcp.SourcePort() + c.Port = tcpHdr.SourcePort() } // RawEndpoint is just a small wrapper around a TCP endpoint's state to make @@ -690,6 +699,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -719,6 +731,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * }), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpSeg := header.TCP(header.IPv4(b).Payload()) synOptions := header.ParseSynOptions(tcpSeg.Options(), false) @@ -782,6 +798,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Store the source port in use by the endpoint. c.Port = tcpSeg.SourcePort() @@ -821,10 +840,16 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := ep.Listen(10); err != nil { c.t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) @@ -847,6 +872,10 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption c.t.Fatalf("Timed out waiting for accept") } } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + return rep } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3d52a4f31..fa7278286 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1000,3 +1000,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements socket.Socket.State. +func (e *endpoint) State() uint32 { + // TODO(b/112063468): Translate internal state to values returned by Linux. + return 0 +} diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 6d745f728..82d325c17 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -34,6 +34,16 @@ using absl::StrFormat; constexpr char kProcNetUnixHeader[] = "Num RefCount Protocol Flags Type St Inode Path"; +// Possible values of the "st" field in a /proc/net/unix entry. Source: Linux +// kernel, include/uapi/linux/net.h. +enum { + SS_FREE = 0, // Not allocated + SS_UNCONNECTED, // Unconnected to any socket + SS_CONNECTING, // In process of connecting + SS_CONNECTED, // Connected to socket + SS_DISCONNECTING // In process of disconnecting +}; + // UnixEntry represents a single entry from /proc/net/unix. struct UnixEntry { uintptr_t addr; @@ -71,7 +81,12 @@ PosixErrorOr> ProcNetUnixEntries() { bool skipped_header = false; std::vector entries; std::vector lines = absl::StrSplit(content, absl::ByAnyChar("\n")); + std::cerr << "" << std::endl; for (std::string line : lines) { + // Emit the proc entry to the test output to provide context for the test + // results. + std::cerr << line << std::endl; + if (!skipped_header) { EXPECT_EQ(line, kProcNetUnixHeader); skipped_header = true; @@ -139,6 +154,7 @@ PosixErrorOr> ProcNetUnixEntries() { entries.push_back(entry); } + std::cerr << "" << std::endl; return entries; } @@ -241,6 +257,168 @@ TEST(ProcNetUnix, SocketPair) { EXPECT_EQ(entries.size(), 2); } +TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, StreamSocketStateStateUnconnectedOnListen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); + + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); + // The bind and listen entries should refer to the same socket. + EXPECT_EQ(listen_entry.inode, bind_entry.inode); +} + +TEST(ProcNetUnix, StreamSocketStateStateConnectedOnAccept) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + const std::string address = ExtractPath(sockets->first_addr()); + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + int clientfd; + ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr), + SyscallSucceeds()); + + // Find the entry for the accepted socket. UDS proc entries don't have a + // remote address, so we distinguish the accepted socket from the listen + // socket by checking for a different inode. + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry accept_entry; + ASSERT_TRUE(FindBy( + entries, &accept_entry, [address, listen_entry](const UnixEntry& e) { + return e.path == address && e.inode != listen_entry.inode; + })); + EXPECT_EQ(accept_entry.state, SS_CONNECTED); + // Listen entry should still be in SS_UNCONNECTED state. + ASSERT_TRUE(FindBy(entries, &listen_entry, + [&sockets, listen_entry](const UnixEntry& e) { + return e.path == ExtractPath(sockets->first_addr()) && + e.inode == listen_entry.inode; + })); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // Once again, we have no easy way to identify the connecting socket as it has + // no listed address. We can only identify the entry as the "non-bind socket + // entry" on gVisor, where we're guaranteed to have only the two entries we + // create during this test. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + UnixEntry connect_entry; + ASSERT_TRUE( + FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) { + return e.inode != bind_entry.inode; + })); + EXPECT_EQ(connect_entry.state, SS_CONNECTING); + } +} + } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3 From b3f104507d7a04c0ca058cbcacc5ff78d853f4ba Mon Sep 17 00:00:00 2001 From: Jamie Liu Date: Thu, 6 Jun 2019 16:27:09 -0700 Subject: "Implement" mbind(2). We still only advertise a single NUMA node, and ignore mempolicy accordingly, but mbind() at least now succeeds and has effects reflected by get_mempolicy(). Also fix handling of nodemasks: round sizes to unsigned long (as documented and done by Linux), and zero trailing bits when copying them out. PiperOrigin-RevId: 251950859 --- pkg/abi/linux/mm.go | 9 + pkg/sentry/kernel/task.go | 7 +- pkg/sentry/kernel/task_sched.go | 4 +- pkg/sentry/mm/mm.go | 6 + pkg/sentry/mm/syscalls.go | 53 +++++ pkg/sentry/mm/vma.go | 3 + pkg/sentry/syscalls/linux/BUILD | 1 + pkg/sentry/syscalls/linux/linux64.go | 3 +- pkg/sentry/syscalls/linux/sys_mempolicy.go | 312 +++++++++++++++++++++++++++++ pkg/sentry/syscalls/linux/sys_mmap.go | 145 -------------- test/syscalls/linux/BUILD | 1 + test/syscalls/linux/mempolicy.cc | 37 +++- 12 files changed, 426 insertions(+), 155 deletions(-) create mode 100644 pkg/sentry/syscalls/linux/sys_mempolicy.go (limited to 'pkg/abi/linux') diff --git a/pkg/abi/linux/mm.go b/pkg/abi/linux/mm.go index 0b02f938a..cd043dac3 100644 --- a/pkg/abi/linux/mm.go +++ b/pkg/abi/linux/mm.go @@ -114,3 +114,12 @@ const ( MPOL_MODE_FLAGS = (MPOL_F_STATIC_NODES | MPOL_F_RELATIVE_NODES) ) + +// Flags for mbind(2). +const ( + MPOL_MF_STRICT = 1 << 0 + MPOL_MF_MOVE = 1 << 1 + MPOL_MF_MOVE_ALL = 1 << 2 + + MPOL_MF_VALID = MPOL_MF_STRICT | MPOL_MF_MOVE | MPOL_MF_MOVE_ALL +) diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index f9378c2de..4d889422f 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -455,12 +455,13 @@ type Task struct { // single numa node, all policies are no-ops. We only track this information // so that we can return reasonable values if the application calls // get_mempolicy(2) after setting a non-default policy. Note that in the - // real syscall, nodemask can be longer than 4 bytes, but we always report a - // single node so never need to save more than a single bit. + // real syscall, nodemask can be longer than a single unsigned long, but we + // always report a single node so never need to save more than a single + // bit. // // numaPolicy and numaNodeMask are protected by mu. numaPolicy int32 - numaNodeMask uint32 + numaNodeMask uint64 // If netns is true, the task is in a non-root network namespace. Network // namespaces aren't currently implemented in full; being in a network diff --git a/pkg/sentry/kernel/task_sched.go b/pkg/sentry/kernel/task_sched.go index 5455f6ea9..1c94ab11b 100644 --- a/pkg/sentry/kernel/task_sched.go +++ b/pkg/sentry/kernel/task_sched.go @@ -622,14 +622,14 @@ func (t *Task) SetNiceness(n int) { } // NumaPolicy returns t's current numa policy. -func (t *Task) NumaPolicy() (policy int32, nodeMask uint32) { +func (t *Task) NumaPolicy() (policy int32, nodeMask uint64) { t.mu.Lock() defer t.mu.Unlock() return t.numaPolicy, t.numaNodeMask } // SetNumaPolicy sets t's numa policy. -func (t *Task) SetNumaPolicy(policy int32, nodeMask uint32) { +func (t *Task) SetNumaPolicy(policy int32, nodeMask uint64) { t.mu.Lock() defer t.mu.Unlock() t.numaPolicy = policy diff --git a/pkg/sentry/mm/mm.go b/pkg/sentry/mm/mm.go index 0a026ff8c..604866d04 100644 --- a/pkg/sentry/mm/mm.go +++ b/pkg/sentry/mm/mm.go @@ -276,6 +276,12 @@ type vma struct { mlockMode memmap.MLockMode + // numaPolicy is the NUMA policy for this vma set by mbind(). + numaPolicy int32 + + // numaNodemask is the NUMA nodemask for this vma set by mbind(). + numaNodemask uint64 + // If id is not nil, it controls the lifecycle of mappable and provides vma // metadata shown in /proc/[pid]/maps, and the vma holds a reference. id memmap.MappingIdentity diff --git a/pkg/sentry/mm/syscalls.go b/pkg/sentry/mm/syscalls.go index af1e53f5d..9cf136532 100644 --- a/pkg/sentry/mm/syscalls.go +++ b/pkg/sentry/mm/syscalls.go @@ -973,6 +973,59 @@ func (mm *MemoryManager) MLockAll(ctx context.Context, opts MLockAllOpts) error return nil } +// NumaPolicy implements the semantics of Linux's get_mempolicy(MPOL_F_ADDR). +func (mm *MemoryManager) NumaPolicy(addr usermem.Addr) (int32, uint64, error) { + mm.mappingMu.RLock() + defer mm.mappingMu.RUnlock() + vseg := mm.vmas.FindSegment(addr) + if !vseg.Ok() { + return 0, 0, syserror.EFAULT + } + vma := vseg.ValuePtr() + return vma.numaPolicy, vma.numaNodemask, nil +} + +// SetNumaPolicy implements the semantics of Linux's mbind(). +func (mm *MemoryManager) SetNumaPolicy(addr usermem.Addr, length uint64, policy int32, nodemask uint64) error { + if !addr.IsPageAligned() { + return syserror.EINVAL + } + // Linux allows this to overflow. + la, _ := usermem.Addr(length).RoundUp() + ar, ok := addr.ToRange(uint64(la)) + if !ok { + return syserror.EINVAL + } + if ar.Length() == 0 { + return nil + } + + mm.mappingMu.Lock() + defer mm.mappingMu.Unlock() + defer func() { + mm.vmas.MergeRange(ar) + mm.vmas.MergeAdjacent(ar) + }() + vseg := mm.vmas.LowerBoundSegment(ar.Start) + lastEnd := ar.Start + for { + if !vseg.Ok() || lastEnd < vseg.Start() { + // "EFAULT: ... there was an unmapped hole in the specified memory + // range specified [sic] by addr and len." - mbind(2) + return syserror.EFAULT + } + vseg = mm.vmas.Isolate(vseg, ar) + vma := vseg.ValuePtr() + vma.numaPolicy = policy + vma.numaNodemask = nodemask + lastEnd = vseg.End() + if ar.End <= lastEnd { + return nil + } + vseg, _ = vseg.NextNonEmpty() + } +} + // Decommit implements the semantics of Linux's madvise(MADV_DONTNEED). func (mm *MemoryManager) Decommit(addr usermem.Addr, length uint64) error { ar, ok := addr.ToRange(length) diff --git a/pkg/sentry/mm/vma.go b/pkg/sentry/mm/vma.go index 02203f79f..0af8de5b0 100644 --- a/pkg/sentry/mm/vma.go +++ b/pkg/sentry/mm/vma.go @@ -107,6 +107,7 @@ func (mm *MemoryManager) createVMALocked(ctx context.Context, opts memmap.MMapOp private: opts.Private, growsDown: opts.GrowsDown, mlockMode: opts.MLockMode, + numaPolicy: linux.MPOL_DEFAULT, id: opts.MappingIdentity, hint: opts.Hint, } @@ -436,6 +437,8 @@ func (vmaSetFunctions) Merge(ar1 usermem.AddrRange, vma1 vma, ar2 usermem.AddrRa vma1.private != vma2.private || vma1.growsDown != vma2.growsDown || vma1.mlockMode != vma2.mlockMode || + vma1.numaPolicy != vma2.numaPolicy || + vma1.numaNodemask != vma2.numaNodemask || vma1.id != vma2.id || vma1.hint != vma2.hint { return vma{}, false diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index f76989ae2..1c057526b 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -19,6 +19,7 @@ go_library( "sys_identity.go", "sys_inotify.go", "sys_lseek.go", + "sys_mempolicy.go", "sys_mmap.go", "sys_mount.go", "sys_pipe.go", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 3e4d312af..ad88b1391 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -360,8 +360,7 @@ var AMD64 = &kernel.SyscallTable{ 235: Utimes, // @Syscall(Vserver, note:Not implemented by Linux) 236: syscalls.Error(syscall.ENOSYS), // Vserver, not implemented by Linux - // @Syscall(Mbind, returns:EPERM or ENOSYS, note:Returns EPERM if the process does not have cap_sys_nice; ENOSYS otherwise), TODO(b/117792295) - 237: syscalls.CapError(linux.CAP_SYS_NICE), // may require cap_sys_nice + 237: Mbind, 238: SetMempolicy, 239: GetMempolicy, // 240: @Syscall(MqOpen), TODO(b/29354921) diff --git a/pkg/sentry/syscalls/linux/sys_mempolicy.go b/pkg/sentry/syscalls/linux/sys_mempolicy.go new file mode 100644 index 000000000..652b2c206 --- /dev/null +++ b/pkg/sentry/syscalls/linux/sys_mempolicy.go @@ -0,0 +1,312 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package linux + +import ( + "fmt" + + "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/sentry/arch" + "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" + "gvisor.googlesource.com/gvisor/pkg/syserror" +) + +// We unconditionally report a single NUMA node. This also means that our +// "nodemask_t" is a single unsigned long (uint64). +const ( + maxNodes = 1 + allowedNodemask = (1 << maxNodes) - 1 +) + +func copyInNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32) (uint64, error) { + // "nodemask points to a bit mask of node IDs that contains up to maxnode + // bits. The bit mask size is rounded to the next multiple of + // sizeof(unsigned long), but the kernel will use bits only up to maxnode. + // A NULL value of nodemask or a maxnode value of zero specifies the empty + // set of nodes. If the value of maxnode is zero, the nodemask argument is + // ignored." - set_mempolicy(2). Unfortunately, most of this is inaccurate + // because of what appears to be a bug: mm/mempolicy.c:get_nodes() uses + // maxnode-1, not maxnode, as the number of bits. + bits := maxnode - 1 + if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0 + return 0, syserror.EINVAL + } + if bits == 0 { + return 0, nil + } + // Copy in the whole nodemask. + numUint64 := (bits + 63) / 64 + buf := t.CopyScratchBuffer(int(numUint64) * 8) + if _, err := t.CopyInBytes(addr, buf); err != nil { + return 0, err + } + val := usermem.ByteOrder.Uint64(buf) + // Check that only allowed bits in the first unsigned long in the nodemask + // are set. + if val&^allowedNodemask != 0 { + return 0, syserror.EINVAL + } + // Check that all remaining bits in the nodemask are 0. + for i := 8; i < len(buf); i++ { + if buf[i] != 0 { + return 0, syserror.EINVAL + } + } + return val, nil +} + +func copyOutNodemask(t *kernel.Task, addr usermem.Addr, maxnode uint32, val uint64) error { + // mm/mempolicy.c:copy_nodes_to_user() also uses maxnode-1 as the number of + // bits. + bits := maxnode - 1 + if bits > usermem.PageSize*8 { // also handles overflow from maxnode == 0 + return syserror.EINVAL + } + if bits == 0 { + return nil + } + // Copy out the first unsigned long in the nodemask. + buf := t.CopyScratchBuffer(8) + usermem.ByteOrder.PutUint64(buf, val) + if _, err := t.CopyOutBytes(addr, buf); err != nil { + return err + } + // Zero out remaining unsigned longs in the nodemask. + if bits > 64 { + remAddr, ok := addr.AddLength(8) + if !ok { + return syserror.EFAULT + } + remUint64 := (bits - 1) / 64 + if _, err := t.MemoryManager().ZeroOut(t, remAddr, int64(remUint64)*8, usermem.IOOpts{ + AddressSpaceActive: true, + }); err != nil { + return err + } + } + return nil +} + +// GetMempolicy implements the syscall get_mempolicy(2). +func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + mode := args[0].Pointer() + nodemask := args[1].Pointer() + maxnode := args[2].Uint() + addr := args[3].Pointer() + flags := args[4].Uint() + + if flags&^(linux.MPOL_F_NODE|linux.MPOL_F_ADDR|linux.MPOL_F_MEMS_ALLOWED) != 0 { + return 0, nil, syserror.EINVAL + } + nodeFlag := flags&linux.MPOL_F_NODE != 0 + addrFlag := flags&linux.MPOL_F_ADDR != 0 + memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0 + + // "EINVAL: The value specified by maxnode is less than the number of node + // IDs supported by the system." - get_mempolicy(2) + if nodemask != 0 && maxnode < maxNodes { + return 0, nil, syserror.EINVAL + } + + // "If flags specifies MPOL_F_MEMS_ALLOWED [...], the mode argument is + // ignored and the set of nodes (memories) that the thread is allowed to + // specify in subsequent calls to mbind(2) or set_mempolicy(2) (in the + // absence of any mode flags) is returned in nodemask." + if memsAllowed { + // "It is not permitted to combine MPOL_F_MEMS_ALLOWED with either + // MPOL_F_ADDR or MPOL_F_NODE." + if nodeFlag || addrFlag { + return 0, nil, syserror.EINVAL + } + if err := copyOutNodemask(t, nodemask, maxnode, allowedNodemask); err != nil { + return 0, nil, err + } + return 0, nil, nil + } + + // "If flags specifies MPOL_F_ADDR, then information is returned about the + // policy governing the memory address given in addr. ... If the mode + // argument is not NULL, then get_mempolicy() will store the policy mode + // and any optional mode flags of the requested NUMA policy in the location + // pointed to by this argument. If nodemask is not NULL, then the nodemask + // associated with the policy will be stored in the location pointed to by + // this argument." + if addrFlag { + policy, nodemaskVal, err := t.MemoryManager().NumaPolicy(addr) + if err != nil { + return 0, nil, err + } + if nodeFlag { + // "If flags specifies both MPOL_F_NODE and MPOL_F_ADDR, + // get_mempolicy() will return the node ID of the node on which the + // address addr is allocated into the location pointed to by mode. + // If no page has yet been allocated for the specified address, + // get_mempolicy() will allocate a page as if the thread had + // performed a read (load) access to that address, and return the + // ID of the node where that page was allocated." + buf := t.CopyScratchBuffer(1) + _, err := t.CopyInBytes(addr, buf) + if err != nil { + return 0, nil, err + } + policy = 0 // maxNodes == 1 + } + if mode != 0 { + if _, err := t.CopyOut(mode, policy); err != nil { + return 0, nil, err + } + } + if nodemask != 0 { + if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil { + return 0, nil, err + } + } + return 0, nil, nil + } + + // "EINVAL: ... flags specified MPOL_F_ADDR and addr is NULL, or flags did + // not specify MPOL_F_ADDR and addr is not NULL." This is partially + // inaccurate: if flags specifies MPOL_F_ADDR, + // mm/mempolicy.c:do_get_mempolicy() doesn't special-case NULL; it will + // just (usually) fail to find a VMA at address 0 and return EFAULT. + if addr != 0 { + return 0, nil, syserror.EINVAL + } + + // "If flags is specified as 0, then information about the calling thread's + // default policy (as set by set_mempolicy(2)) is returned, in the buffers + // pointed to by mode and nodemask. ... If flags specifies MPOL_F_NODE, but + // not MPOL_F_ADDR, and the thread's current policy is MPOL_INTERLEAVE, + // then get_mempolicy() will return in the location pointed to by a + // non-NULL mode argument, the node ID of the next node that will be used + // for interleaving of internal kernel pages allocated on behalf of the + // thread." + policy, nodemaskVal := t.NumaPolicy() + if nodeFlag { + if policy&^linux.MPOL_MODE_FLAGS != linux.MPOL_INTERLEAVE { + return 0, nil, syserror.EINVAL + } + policy = 0 // maxNodes == 1 + } + if mode != 0 { + if _, err := t.CopyOut(mode, policy); err != nil { + return 0, nil, err + } + } + if nodemask != 0 { + if err := copyOutNodemask(t, nodemask, maxnode, nodemaskVal); err != nil { + return 0, nil, err + } + } + return 0, nil, nil +} + +// SetMempolicy implements the syscall set_mempolicy(2). +func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + modeWithFlags := args[0].Int() + nodemask := args[1].Pointer() + maxnode := args[2].Uint() + + modeWithFlags, nodemaskVal, err := copyInMempolicyNodemask(t, modeWithFlags, nodemask, maxnode) + if err != nil { + return 0, nil, err + } + + t.SetNumaPolicy(modeWithFlags, nodemaskVal) + return 0, nil, nil +} + +// Mbind implements the syscall mbind(2). +func Mbind(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + addr := args[0].Pointer() + length := args[1].Uint64() + mode := args[2].Int() + nodemask := args[3].Pointer() + maxnode := args[4].Uint() + flags := args[5].Uint() + + if flags&^linux.MPOL_MF_VALID != 0 { + return 0, nil, syserror.EINVAL + } + // "If MPOL_MF_MOVE_ALL is passed in flags ... [the] calling thread must be + // privileged (CAP_SYS_NICE) to use this flag." - mbind(2) + if flags&linux.MPOL_MF_MOVE_ALL != 0 && !t.HasCapability(linux.CAP_SYS_NICE) { + return 0, nil, syserror.EPERM + } + + mode, nodemaskVal, err := copyInMempolicyNodemask(t, mode, nodemask, maxnode) + if err != nil { + return 0, nil, err + } + + // Since we claim to have only a single node, all flags can be ignored + // (since all pages must already be on that single node). + err = t.MemoryManager().SetNumaPolicy(addr, length, mode, nodemaskVal) + return 0, nil, err +} + +func copyInMempolicyNodemask(t *kernel.Task, modeWithFlags int32, nodemask usermem.Addr, maxnode uint32) (int32, uint64, error) { + flags := modeWithFlags & linux.MPOL_MODE_FLAGS + mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS + if flags == linux.MPOL_MODE_FLAGS { + // Can't specify both mode flags simultaneously. + return 0, 0, syserror.EINVAL + } + if mode < 0 || mode >= linux.MPOL_MAX { + // Must specify a valid mode. + return 0, 0, syserror.EINVAL + } + + var nodemaskVal uint64 + if nodemask != 0 { + var err error + nodemaskVal, err = copyInNodemask(t, nodemask, maxnode) + if err != nil { + return 0, 0, err + } + } + + switch mode { + case linux.MPOL_DEFAULT: + // "nodemask must be specified as NULL." - set_mempolicy(2). This is inaccurate; + // Linux allows a nodemask to be specified, as long as it is empty. + if nodemaskVal != 0 { + return 0, 0, syserror.EINVAL + } + case linux.MPOL_BIND, linux.MPOL_INTERLEAVE: + // These require a non-empty nodemask. + if nodemaskVal == 0 { + return 0, 0, syserror.EINVAL + } + case linux.MPOL_PREFERRED: + // This permits an empty nodemask, as long as no flags are set. + if nodemaskVal == 0 && flags != 0 { + return 0, 0, syserror.EINVAL + } + case linux.MPOL_LOCAL: + // This requires an empty nodemask and no flags set ... + if nodemaskVal != 0 || flags != 0 { + return 0, 0, syserror.EINVAL + } + // ... and is implemented as MPOL_PREFERRED. + mode = linux.MPOL_PREFERRED + default: + // Unknown mode, which we should have rejected above. + panic(fmt.Sprintf("unknown mode: %v", mode)) + } + + return mode | flags, nodemaskVal, nil +} diff --git a/pkg/sentry/syscalls/linux/sys_mmap.go b/pkg/sentry/syscalls/linux/sys_mmap.go index 64a6e639c..9926f0ac5 100644 --- a/pkg/sentry/syscalls/linux/sys_mmap.go +++ b/pkg/sentry/syscalls/linux/sys_mmap.go @@ -204,151 +204,6 @@ func Madvise(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sysca } } -func copyOutIfNotNull(t *kernel.Task, ptr usermem.Addr, val interface{}) (int, error) { - if ptr != 0 { - return t.CopyOut(ptr, val) - } - return 0, nil -} - -// GetMempolicy implements the syscall get_mempolicy(2). -func GetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - mode := args[0].Pointer() - nodemask := args[1].Pointer() - maxnode := args[2].Uint() - addr := args[3].Pointer() - flags := args[4].Uint() - - memsAllowed := flags&linux.MPOL_F_MEMS_ALLOWED != 0 - nodeFlag := flags&linux.MPOL_F_NODE != 0 - addrFlag := flags&linux.MPOL_F_ADDR != 0 - - // TODO(rahat): Once sysfs is implemented, report a single numa node in - // /sys/devices/system/node. - if nodemask != 0 && maxnode < 1 { - return 0, nil, syserror.EINVAL - } - - // 'addr' provided iff 'addrFlag' set. - if addrFlag == (addr == 0) { - return 0, nil, syserror.EINVAL - } - - // Default policy for the thread. - if flags == 0 { - policy, nodemaskVal := t.NumaPolicy() - if _, err := copyOutIfNotNull(t, mode, policy); err != nil { - return 0, nil, syserror.EFAULT - } - if _, err := copyOutIfNotNull(t, nodemask, nodemaskVal); err != nil { - return 0, nil, syserror.EFAULT - } - return 0, nil, nil - } - - // Report all nodes available to caller. - if memsAllowed { - // MPOL_F_NODE and MPOL_F_ADDR not allowed with MPOL_F_MEMS_ALLOWED. - if nodeFlag || addrFlag { - return 0, nil, syserror.EINVAL - } - - // Report a single numa node. - if _, err := copyOutIfNotNull(t, nodemask, uint32(0x1)); err != nil { - return 0, nil, syserror.EFAULT - } - return 0, nil, nil - } - - if addrFlag { - if nodeFlag { - // Return the id for the node where 'addr' resides, via 'mode'. - // - // The real get_mempolicy(2) allocates the page referenced by 'addr' - // by simulating a read, if it is unallocated before the call. It - // then returns the node the page is allocated on through the mode - // pointer. - b := t.CopyScratchBuffer(1) - _, err := t.CopyInBytes(addr, b) - if err != nil { - return 0, nil, syserror.EFAULT - } - if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil { - return 0, nil, syserror.EFAULT - } - } else { - storedPolicy, _ := t.NumaPolicy() - // Return the policy governing the memory referenced by 'addr'. - if _, err := copyOutIfNotNull(t, mode, int32(storedPolicy)); err != nil { - return 0, nil, syserror.EFAULT - } - } - return 0, nil, nil - } - - storedPolicy, _ := t.NumaPolicy() - if nodeFlag && (storedPolicy&^linux.MPOL_MODE_FLAGS == linux.MPOL_INTERLEAVE) { - // Policy for current thread is to interleave memory between - // nodes. Return the next node we'll allocate on. Since we only have a - // single node, this is always node 0. - if _, err := copyOutIfNotNull(t, mode, int32(0)); err != nil { - return 0, nil, syserror.EFAULT - } - return 0, nil, nil - } - - return 0, nil, syserror.EINVAL -} - -func allowedNodesMask() uint32 { - const maxNodes = 1 - return ^uint32((1 << maxNodes) - 1) -} - -// SetMempolicy implements the syscall set_mempolicy(2). -func SetMempolicy(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { - modeWithFlags := args[0].Int() - nodemask := args[1].Pointer() - maxnode := args[2].Uint() - - if nodemask != 0 && maxnode < 1 { - return 0, nil, syserror.EINVAL - } - - if modeWithFlags&linux.MPOL_MODE_FLAGS == linux.MPOL_MODE_FLAGS { - // Can't specify multiple modes simultaneously. - return 0, nil, syserror.EINVAL - } - - mode := modeWithFlags &^ linux.MPOL_MODE_FLAGS - if mode < 0 || mode >= linux.MPOL_MAX { - // Must specify a valid mode. - return 0, nil, syserror.EINVAL - } - - var nodemaskVal uint32 - // Nodemask may be empty for some policy modes. - if nodemask != 0 && maxnode > 0 { - if _, err := t.CopyIn(nodemask, &nodemaskVal); err != nil { - return 0, nil, syserror.EFAULT - } - } - - if (mode == linux.MPOL_INTERLEAVE || mode == linux.MPOL_BIND) && nodemaskVal == 0 { - // Mode requires a non-empty nodemask, but got an empty nodemask. - return 0, nil, syserror.EINVAL - } - - if nodemaskVal&allowedNodesMask() != 0 { - // Invalid node specified. - return 0, nil, syserror.EINVAL - } - - t.SetNumaPolicy(int32(modeWithFlags), nodemaskVal) - - return 0, nil, nil -} - // Mincore implements the syscall mincore(2). func Mincore(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { addr := args[0].Pointer() diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 0cb7b47b6..9bafc6e4f 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -999,6 +999,7 @@ cc_binary( linkstatic = 1, deps = [ "//test/util:cleanup", + "//test/util:memory_util", "//test/util:test_main", "//test/util:test_util", "//test/util:thread_util", diff --git a/test/syscalls/linux/mempolicy.cc b/test/syscalls/linux/mempolicy.cc index 4ac4cb88f..9d5f47651 100644 --- a/test/syscalls/linux/mempolicy.cc +++ b/test/syscalls/linux/mempolicy.cc @@ -18,6 +18,7 @@ #include "gtest/gtest.h" #include "absl/memory/memory.h" #include "test/util/cleanup.h" +#include "test/util/memory_util.h" #include "test/util/test_util.h" #include "test/util/thread_util.h" @@ -34,7 +35,7 @@ namespace { #define MPOL_PREFERRED 1 #define MPOL_BIND 2 #define MPOL_INTERLEAVE 3 -#define MPOL_MAX MPOL_INTERLEAVE +#define MPOL_LOCAL 4 #define MPOL_F_NODE (1 << 0) #define MPOL_F_ADDR (1 << 1) #define MPOL_F_MEMS_ALLOWED (1 << 2) @@ -44,11 +45,17 @@ namespace { int get_mempolicy(int *policy, uint64_t *nmask, uint64_t maxnode, void *addr, int flags) { - return syscall(__NR_get_mempolicy, policy, nmask, maxnode, addr, flags); + return syscall(SYS_get_mempolicy, policy, nmask, maxnode, addr, flags); } int set_mempolicy(int mode, uint64_t *nmask, uint64_t maxnode) { - return syscall(__NR_set_mempolicy, mode, nmask, maxnode); + return syscall(SYS_set_mempolicy, mode, nmask, maxnode); +} + +int mbind(void *addr, unsigned long len, int mode, + const unsigned long *nodemask, unsigned long maxnode, + unsigned flags) { + return syscall(SYS_mbind, addr, len, mode, nodemask, maxnode, flags); } // Creates a cleanup object that resets the calling thread's mempolicy to the @@ -252,6 +259,30 @@ TEST(MempolicyTest, GetMempolicyNextInterleaveNode) { EXPECT_EQ(0, mode); } +TEST(MempolicyTest, Mbind) { + // Temporarily set the thread policy to MPOL_PREFERRED. + const auto cleanup_thread_policy = + ASSERT_NO_ERRNO_AND_VALUE(ScopedSetMempolicy(MPOL_PREFERRED, nullptr, 0)); + + const auto mapping = ASSERT_NO_ERRNO_AND_VALUE( + MmapAnon(kPageSize, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANONYMOUS)); + + // vmas default to MPOL_DEFAULT irrespective of the thread policy (currently + // MPOL_PREFERRED). + int mode; + ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR), + SyscallSucceeds()); + EXPECT_EQ(mode, MPOL_DEFAULT); + + // Set MPOL_PREFERRED for the vma and read it back. + ASSERT_THAT( + mbind(mapping.ptr(), mapping.len(), MPOL_PREFERRED, nullptr, 0, 0), + SyscallSucceeds()); + ASSERT_THAT(get_mempolicy(&mode, nullptr, 0, mapping.ptr(), MPOL_F_ADDR), + SyscallSucceeds()); + EXPECT_EQ(mode, MPOL_PREFERRED); +} + } // namespace } // namespace testing -- cgit v1.2.3 From 315cf9a523d409dc6ddd5ce25f8f0315068ccc67 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 16:59:21 -0700 Subject: Use common definition of SockType. SockType isn't specific to unix domain sockets, and the current definition basically mirrors the linux ABI's definition. PiperOrigin-RevId: 251956740 --- pkg/abi/linux/socket.go | 18 +++++++++++------- pkg/sentry/fs/gofer/socket.go | 11 ++++++----- pkg/sentry/fs/host/socket.go | 11 ++++++----- pkg/sentry/socket/epsocket/BUILD | 1 - pkg/sentry/socket/epsocket/epsocket.go | 11 +++++------ pkg/sentry/socket/epsocket/provider.go | 7 +++---- pkg/sentry/socket/hostinet/BUILD | 1 - pkg/sentry/socket/hostinet/socket.go | 5 ++--- pkg/sentry/socket/netlink/provider.go | 7 +++---- pkg/sentry/socket/rpcinet/BUILD | 1 - pkg/sentry/socket/rpcinet/socket.go | 9 ++++----- pkg/sentry/socket/socket.go | 8 ++++---- pkg/sentry/socket/unix/transport/connectioned.go | 20 ++++++++++---------- pkg/sentry/socket/unix/transport/connectionless.go | 4 ++-- pkg/sentry/socket/unix/transport/unix.go | 22 ++++------------------ pkg/sentry/socket/unix/unix.go | 4 ++-- pkg/sentry/strace/socket.go | 14 +++++++------- pkg/sentry/syscalls/linux/sys_socket.go | 4 ++-- 18 files changed, 71 insertions(+), 87 deletions(-) (limited to 'pkg/abi/linux') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 44bd69df6..a714ac86d 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -102,15 +102,19 @@ const ( SOL_NETLINK = 270 ) +// A SockType is a type (as opposed to family) of sockets. These are enumerated +// below as SOCK_* constants. +type SockType int + // Socket types, from linux/net.h. const ( - SOCK_STREAM = 1 - SOCK_DGRAM = 2 - SOCK_RAW = 3 - SOCK_RDM = 4 - SOCK_SEQPACKET = 5 - SOCK_DCCP = 6 - SOCK_PACKET = 10 + SOCK_STREAM SockType = 1 + SOCK_DGRAM = 2 + SOCK_RAW = 3 + SOCK_RDM = 4 + SOCK_SEQPACKET = 5 + SOCK_DCCP = 6 + SOCK_PACKET = 10 ) // SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 7376fd76f..7ac0a421f 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -15,6 +15,7 @@ package gofer import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/p9" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" @@ -61,13 +62,13 @@ type endpoint struct { path string } -func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { +func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) { switch t { - case transport.SockStream: + case linux.SOCK_STREAM: return p9.StreamSocket, true - case transport.SockSeqpacket: + case linux.SOCK_SEQPACKET: return p9.SeqpacketSocket, true - case transport.SockDgram: + case linux.SOCK_DGRAM: return p9.DgramSocket, true } return 0, false @@ -75,7 +76,7 @@ func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { // BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. func (e *endpoint) BidirectionalConnect(ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { - cf, ok := unixSockToP9(ce.Type()) + cf, ok := sockTypeToP9(ce.Type()) if !ok { return syserr.ErrConnectionRefused } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index e4ec0f62c..6423ad938 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -19,6 +19,7 @@ import ( "sync" "syscall" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/fd" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" "gvisor.googlesource.com/gvisor/pkg/log" @@ -56,7 +57,7 @@ type ConnectedEndpoint struct { srfd int `state:"wait"` // stype is the type of Unix socket. - stype transport.SockType + stype linux.SockType // sndbuf is the size of the send buffer. // @@ -105,7 +106,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { return syserr.ErrInvalidEndpointState } - c.stype = transport.SockType(stype) + c.stype = linux.SockType(stype) c.sndbuf = sndbuf return nil @@ -163,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -195,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil + return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil } // Send implements transport.ConnectedEndpoint.Send. @@ -209,7 +210,7 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.Contro // Since stream sockets don't preserve message boundaries, we can write // only as much of the message as fits in the send buffer. - truncate := c.stype == transport.SockStream + truncate := c.stype == linux.SOCK_STREAM n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) if n < totalLen && err == nil { diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 44bb97b5b..7e2679ea0 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f91c5127a..e1e29de35 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -44,7 +44,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -228,7 +227,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint - skType transport.SockType + skType linux.SockType // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -253,8 +252,8 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { - if skType == transport.SockStream { +func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { + if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -638,7 +637,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -664,7 +663,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_TYPE: diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index ec930d8d5..e48a106ea 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -23,7 +23,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" @@ -42,7 +41,7 @@ type provider struct { // getTransportProtocol figures out transport protocol. Currently only TCP, // UDP, and ICMP are supported. -func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { @@ -80,7 +79,7 @@ func getTransportProtocol(ctx context.Context, stype transport.SockType, protoco } // Socket creates a new socket object for the AF_INET or AF_INET6 family. -func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() if stack == nil { @@ -116,7 +115,7 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int } // Pair just returns nil sockets (not supported). -func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a469af7ac..975f47bc3 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -30,7 +30,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 0d75580a3..4517951a0 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -30,7 +30,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" @@ -548,7 +547,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the host network stack. stack := t.NetworkContext() if stack == nil { @@ -590,7 +589,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 76cf12fd4..863edc241 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -22,7 +22,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" ) @@ -66,10 +65,10 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Netlink sockets must be specified as datagram or raw, but they // behave the same regardless of type. - if stype != transport.SockDgram && stype != transport.SockRaw { + if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW { return nil, syserr.ErrSocketNotSupported } @@ -94,7 +93,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol } // Pair implements socket.Provider.Pair by returning an error. -func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Netlink sockets never supports creating socket pairs. return nil, nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 4da14a1e0..33ba20de7 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/sentry/socket/hostinet", "//pkg/sentry/socket/rpcinet/conn", "//pkg/sentry/socket/rpcinet/notifier", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index bf42bdf69..2d5b5b58f 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -32,7 +32,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier" pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -70,7 +69,7 @@ type socketOperations struct { var _ = socket.Socket(&socketOperations{}) // New creates a new RPC socket. -func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) { +func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) { id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */) <-c @@ -841,7 +840,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the RPC network stack. stack := t.NetworkContext() if stack == nil { @@ -857,7 +856,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p // // Try to restrict the flags we will accept to minimize backwards // incompatibility with netstack. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -881,7 +880,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index a99423365..f1021ec67 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -130,12 +130,12 @@ type Provider interface { // If a nil Socket _and_ a nil error is returned, it means that the // protocol is not supported. A non-nil error should only be returned // if the protocol is supported, but an error occurs during creation. - Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) + Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) // Pair creates a pair of connected sockets. // // See Socket for error information. - Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) + Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) } // families holds a map of all known address families and their providers. @@ -149,7 +149,7 @@ func RegisterProvider(family int, provider Provider) { } // New creates a new socket with the given family, type and protocol. -func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { for _, p := range families[family] { s, err := p.Socket(t, stype, protocol) if err != nil { @@ -166,7 +166,7 @@ func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*f // Pair creates a new connected socket pair with the given family, type and // protocol. -func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { providers, ok := families[family] if !ok { return nil, nil, syserr.ErrAddressFamilyNotSupported diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9c8ec0365..db79ac904 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -45,7 +45,7 @@ type ConnectingEndpoint interface { // Type returns the socket type, typically either SockStream or // SockSeqpacket. The connection attempt must be aborted if this // value doesn't match the ConnectableEndpoint's type. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the bound path. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -101,7 +101,7 @@ type connectionedEndpoint struct { // stype is used by connecting sockets to ensure that they are the // same type. The value is typically either tcpip.SockSeqpacket or // tcpip.SockStream. - stype SockType + stype linux.SockType // acceptedChan is per the TCP endpoint implementation. Note that the // sockets in this channel are _already in the connected state_, and @@ -112,7 +112,7 @@ type connectionedEndpoint struct { } // NewConnectioned creates a new unbound connectionedEndpoint. -func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { +func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -122,7 +122,7 @@ func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. -func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { +func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { a := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -139,7 +139,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - if stype == SockStream { + if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} } else { @@ -163,7 +163,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. -func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { +func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), @@ -178,7 +178,7 @@ func (e *connectionedEndpoint) ID() uint64 { } // Type implements ConnectingEndpoint.Type and Endpoint.Type. -func (e *connectionedEndpoint) Type() SockType { +func (e *connectionedEndpoint) Type() linux.SockType { return e.stype } @@ -294,7 +294,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { ne.receiver = &queueReceiver{readQueue: writeQueue} @@ -309,7 +309,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur writeQueue: writeQueue, } readQueue.IncRef() - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) } else { returnConnect(&queueReceiver{readQueue: readQueue}, connected) @@ -429,7 +429,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { // Stream sockets do not support specifying the endpoint. Seqpacket // sockets ignore the passed endpoint. - if e.stype == SockStream && to != nil { + if e.stype == linux.SOCK_STREAM && to != nil { return 0, syserr.ErrNotSupported } return e.baseEndpoint.SendMsg(data, c, to) diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index c034cf984..81ebfba10 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -119,8 +119,8 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo } // Type implements Endpoint.Type. -func (e *connectionlessEndpoint) Type() SockType { - return SockDgram +func (e *connectionlessEndpoint) Type() linux.SockType { + return linux.SOCK_DGRAM } // Connect attempts to connect directly to server. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 5fc09af55..5c55c529e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -19,6 +19,7 @@ import ( "sync" "sync/atomic" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" @@ -28,21 +29,6 @@ import ( // initialLimit is the starting limit for the socket buffers. const initialLimit = 16 * 1024 -// A SockType is a type (as opposed to family) of sockets. These are enumerated -// in the syscall package as syscall.SOCK_* constants. -type SockType int - -const ( - // SockStream corresponds to syscall.SOCK_STREAM. - SockStream SockType = 1 - // SockDgram corresponds to syscall.SOCK_DGRAM. - SockDgram SockType = 2 - // SockRaw corresponds to syscall.SOCK_RAW. - SockRaw SockType = 3 - // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET. - SockSeqpacket SockType = 5 -) - // A RightsControlMessage is a control message containing FDs. type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. @@ -175,7 +161,7 @@ type Endpoint interface { // Type return the socket type, typically either SockStream, SockDgram // or SockSeqpacket. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the address to which the endpoint is bound. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -629,7 +615,7 @@ type connectedEndpoint struct { GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) // Type implements Endpoint.Type. - Type() SockType + Type() linux.SockType } writeQueue *queue @@ -653,7 +639,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, } truncate := false - if e.endpoint.Type() == SockStream { + if e.endpoint.Type() == linux.SOCK_STREAM { // Since stream sockets don't preserve message boundaries, we // can write only as much of the message as fits in the queue. truncate = true diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 375542350..56ed63e21 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -605,7 +605,7 @@ func (s *SocketOperations) State() uint32 { type provider struct{} // Socket returns a new unix domain socket. -func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, syserr.ErrProtocolNotSupported @@ -631,7 +631,7 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) } // Pair creates a new pair of AF_UNIX connected sockets. -func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, nil, syserr.ErrProtocolNotSupported diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index dbe53b9a2..0b5ef84c4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -76,13 +76,13 @@ var SocketFamily = abi.ValueSet{ // SocketType are the possible socket(2) types. var SocketType = abi.ValueSet{ - linux.SOCK_STREAM: "SOCK_STREAM", - linux.SOCK_DGRAM: "SOCK_DGRAM", - linux.SOCK_RAW: "SOCK_RAW", - linux.SOCK_RDM: "SOCK_RDM", - linux.SOCK_SEQPACKET: "SOCK_SEQPACKET", - linux.SOCK_DCCP: "SOCK_DCCP", - linux.SOCK_PACKET: "SOCK_PACKET", + uint64(linux.SOCK_STREAM): "SOCK_STREAM", + uint64(linux.SOCK_DGRAM): "SOCK_DGRAM", + uint64(linux.SOCK_RAW): "SOCK_RAW", + uint64(linux.SOCK_RDM): "SOCK_RDM", + uint64(linux.SOCK_SEQPACKET): "SOCK_SEQPACKET", + uint64(linux.SOCK_DCCP): "SOCK_DCCP", + uint64(linux.SOCK_PACKET): "SOCK_PACKET", } // SocketFlagSet are the possible socket(2) flags. diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 8f4dbf3bc..31295a6a9 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -188,7 +188,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Create the new socket. - s, e := socket.New(t, domain, transport.SockType(stype&0xf), protocol) + s, e := socket.New(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } @@ -227,7 +227,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Create the socket pair. - s1, s2, e := socket.Pair(t, domain, transport.SockType(stype&0xf), protocol) + s1, s2, e := socket.Pair(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } -- cgit v1.2.3