From d3ed9baac0dc967eaf6d3e3f986cafe60604121a Mon Sep 17 00:00:00 2001 From: Michael Pratt <mpratt@google.com> Date: Wed, 5 Jun 2019 13:59:01 -0700 Subject: Implement dumpability tracking and checks We don't actually support core dumps, but some applications want to get/set dumpability, which still has an effect in procfs. Lack of support for set-uid binaries or fs creds simplifies things a bit. As-is, processes started via CreateProcess (i.e., init and sentryctl exec) have normal dumpability. I'm a bit torn on whether sentryctl exec tasks should be dumpable, but at least since they have no parent normal UID/GID checks should protect them. PiperOrigin-RevId: 251712714 --- pkg/sentry/fs/proc/inode.go | 40 ++++++++++++++++++++++++++++++++++++++-- pkg/sentry/fs/proc/task.go | 17 ++++++++++++++++- 2 files changed, 54 insertions(+), 3 deletions(-) (limited to 'pkg/sentry/fs/proc') diff --git a/pkg/sentry/fs/proc/inode.go b/pkg/sentry/fs/proc/inode.go index 379569823..986bc0a45 100644 --- a/pkg/sentry/fs/proc/inode.go +++ b/pkg/sentry/fs/proc/inode.go @@ -21,11 +21,14 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/proc/device" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" + "gvisor.googlesource.com/gvisor/pkg/sentry/mm" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" ) // taskOwnedInodeOps wraps an fs.InodeOperations and overrides the UnstableAttr -// method to return the task as the owner. +// method to return either the task or root as the owner, depending on the +// task's dumpability. // // +stateify savable type taskOwnedInodeOps struct { @@ -41,9 +44,42 @@ func (i *taskOwnedInodeOps) UnstableAttr(ctx context.Context, inode *fs.Inode) ( if err != nil { return fs.UnstableAttr{}, err } - // Set the task owner as the file owner. + + // By default, set the task owner as the file owner. creds := i.t.Credentials() uattr.Owner = fs.FileOwner{creds.EffectiveKUID, creds.EffectiveKGID} + + // Linux doesn't apply dumpability adjustments to world + // readable/executable directories so that applications can stat + // /proc/PID to determine the effective UID of a process. See + // fs/proc/base.c:task_dump_owner. + if fs.IsDir(inode.StableAttr) && uattr.Perms == fs.FilePermsFromMode(0555) { + return uattr, nil + } + + // If the task is not dumpable, then root (in the namespace preferred) + // owns the file. + var m *mm.MemoryManager + i.t.WithMuLocked(func(t *kernel.Task) { + m = t.MemoryManager() + }) + + if m == nil { + uattr.Owner.UID = auth.RootKUID + uattr.Owner.GID = auth.RootKGID + } else if m.Dumpability() != mm.UserDumpable { + if kuid := creds.UserNamespace.MapToKUID(auth.RootUID); kuid.Ok() { + uattr.Owner.UID = kuid + } else { + uattr.Owner.UID = auth.RootKUID + } + if kgid := creds.UserNamespace.MapToKGID(auth.RootGID); kgid.Ok() { + uattr.Owner.GID = kgid + } else { + uattr.Owner.GID = auth.RootKGID + } + } + return uattr, nil } diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 77e03d349..21a965f90 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -96,7 +96,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, showSubtasks boo contents["cgroup"] = newCGroupInode(t, msrc, p.cgroupControllers) } - // TODO(b/31916171): Set EUID/EGID based on dumpability. + // N.B. taskOwnedInodeOps enforces dumpability-based ownership. d := &taskDir{ Dir: *ramfs.NewDir(t, contents, fs.RootOwner, fs.FilePermsFromMode(0555)), t: t, @@ -667,6 +667,21 @@ func newComm(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { return newProcInode(c, msrc, fs.SpecialFile, t) } +// Check implements fs.InodeOperations.Check. +func (c *comm) Check(ctx context.Context, inode *fs.Inode, p fs.PermMask) bool { + // This file can always be read or written by members of the same + // thread group. See fs/proc/base.c:proc_tid_comm_permission. + // + // N.B. This check is currently a no-op as we don't yet support writing + // and this file is world-readable anyways. + t := kernel.TaskFromContext(ctx) + if t != nil && t.ThreadGroup() == c.t.ThreadGroup() && !p.Execute { + return true + } + + return fs.ContextCanAccessFile(ctx, inode, p) +} + // GetFile implements fs.InodeOperations.GetFile. func (c *comm) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { return fs.NewFile(ctx, dirent, flags, &commFile{t: c.t}), nil -- cgit v1.2.3 From 2d2831e3541c8ae3c84f17cfd1bf0a26f2027044 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood <rahat@google.com> Date: Thu, 6 Jun 2019 15:03:44 -0700 Subject: Track and export socket state. This is necessary for implementing network diagnostic interfaces like /proc/net/{tcp,udp,unix} and sock_diag(7). For pass-through endpoints such as hostinet, we obtain the socket state from the backend. For netstack, we add explicit tracking of TCP states. PiperOrigin-RevId: 251934850 --- pkg/abi/linux/socket.go | 16 ++ pkg/sentry/fs/proc/net.go | 20 +-- pkg/sentry/socket/epsocket/epsocket.go | 44 +++++ pkg/sentry/socket/hostinet/socket.go | 24 +++ pkg/sentry/socket/netlink/socket.go | 5 + pkg/sentry/socket/rpcinet/socket.go | 6 + pkg/sentry/socket/socket.go | 4 + pkg/sentry/socket/unix/transport/BUILD | 1 + pkg/sentry/socket/unix/transport/connectioned.go | 9 ++ pkg/sentry/socket/unix/transport/connectionless.go | 16 ++ pkg/sentry/socket/unix/transport/unix.go | 4 + pkg/sentry/socket/unix/unix.go | 5 + pkg/tcpip/stack/transport_test.go | 4 + pkg/tcpip/tcpip.go | 4 + pkg/tcpip/transport/icmp/BUILD | 1 + pkg/tcpip/transport/icmp/endpoint.go | 6 + pkg/tcpip/transport/raw/endpoint.go | 5 + pkg/tcpip/transport/tcp/accept.go | 12 +- pkg/tcpip/transport/tcp/connect.go | 26 ++- pkg/tcpip/transport/tcp/endpoint.go | 174 ++++++++++++++------ pkg/tcpip/transport/tcp/endpoint_state.go | 42 ++--- pkg/tcpip/transport/tcp/rcv.go | 37 +++++ pkg/tcpip/transport/tcp/snd.go | 4 + pkg/tcpip/transport/tcp/tcp_test.go | 131 ++++++++++++--- pkg/tcpip/transport/tcp/testing/context/context.go | 39 ++++- pkg/tcpip/transport/udp/endpoint.go | 6 + test/syscalls/linux/proc_net_unix.cc | 178 +++++++++++++++++++++ 27 files changed, 696 insertions(+), 127 deletions(-) (limited to 'pkg/sentry/fs/proc') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 417840731..44bd69df6 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -200,6 +200,22 @@ const ( SS_DISCONNECTING = 4 // In process of disconnecting. ) +// TCP protocol states, from include/net/tcp_states.h. +const ( + TCP_ESTABLISHED uint32 = iota + 1 + TCP_SYN_SENT + TCP_SYN_RECV + TCP_FIN_WAIT1 + TCP_FIN_WAIT2 + TCP_TIME_WAIT + TCP_CLOSE + TCP_CLOSE_WAIT + TCP_LAST_ACK + TCP_LISTEN + TCP_CLOSING + TCP_NEW_SYN_RECV +) + // SockAddrMax is the maximum size of a struct sockaddr, from // uapi/linux/socket.h. const SockAddrMax = 128 diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 4a107c739..3daaa962c 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -240,24 +240,6 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s } } - var sockState int - switch sops.Endpoint().Type() { - case linux.SOCK_DGRAM: - sockState = linux.SS_CONNECTING - // Unlike Linux, we don't have unbound connection-less sockets, - // so no SS_DISCONNECTING. - - case linux.SOCK_SEQPACKET: - fallthrough - case linux.SOCK_STREAM: - // Connectioned. - if sops.Endpoint().(transport.ConnectingEndpoint).Connected() { - sockState = linux.SS_CONNECTED - } else { - sockState = linux.SS_UNCONNECTED - } - } - // In the socket entry below, the value for the 'Num' field requires // some consideration. Linux prints the address to the struct // unix_sock representing a socket in the kernel, but may redact the @@ -282,7 +264,7 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s 0, // Protocol, always 0 for UDS. sockFlags, // Flags. sops.Endpoint().Type(), // Type. - sockState, // State. + sops.State(), // State. sfile.InodeID(), // Inode. ) diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index de4b963da..f91c5127a 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -52,6 +52,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -2281,3 +2282,46 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { } return rv } + +// State implements socket.Socket.State. State translates the internal state +// returned by netstack to values defined by Linux. +func (s *SocketOperations) State() uint32 { + if s.family != linux.AF_INET && s.family != linux.AF_INET6 { + // States not implemented for this socket's family. + return 0 + } + + if !s.isPacketBased() { + // TCP socket. + switch tcp.EndpointState(s.Endpoint.State()) { + case tcp.StateEstablished: + return linux.TCP_ESTABLISHED + case tcp.StateSynSent: + return linux.TCP_SYN_SENT + case tcp.StateSynRecv: + return linux.TCP_SYN_RECV + case tcp.StateFinWait1: + return linux.TCP_FIN_WAIT1 + case tcp.StateFinWait2: + return linux.TCP_FIN_WAIT2 + case tcp.StateTimeWait: + return linux.TCP_TIME_WAIT + case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError: + return linux.TCP_CLOSE + case tcp.StateCloseWait: + return linux.TCP_CLOSE_WAIT + case tcp.StateLastAck: + return linux.TCP_LAST_ACK + case tcp.StateListen: + return linux.TCP_LISTEN + case tcp.StateClosing: + return linux.TCP_CLOSING + default: + // Internal or unknown state. + return 0 + } + } + + // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. + return 0 +} diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 41f9693bb..0d75580a3 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -19,7 +19,9 @@ import ( "syscall" "gvisor.googlesource.com/gvisor/pkg/abi/linux" + "gvisor.googlesource.com/gvisor/pkg/binary" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" + "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/sentry/context" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil" @@ -519,6 +521,28 @@ func translateIOSyscallError(err error) error { return err } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + info := linux.TCPInfo{} + buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo) + if err != nil { + if err != syscall.ENOPROTOOPT { + log.Warningf("Failed to get TCP socket info from %+v: %v", s, err) + } + // For non-TCP sockets, silently ignore the failure. + return 0 + } + if len(buf) != linux.SizeOfTCPInfo { + // Unmarshal below will panic if getsockopt returns a buffer of + // unexpected size. + log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo) + return 0 + } + + binary.Unmarshal(buf, usermem.ByteOrder, &info) + return uint32(info.State) +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index afd06ca33..16c79aa33 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -616,3 +616,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) return int64(n), err.ToError() } + +// State implements socket.Socket.State. +func (s *Socket) State() uint32 { + return s.ep.State() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 55e0b6665..bf42bdf69 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -830,6 +830,12 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } } +// State implements socket.Socket.State. +func (s *socketOperations) State() uint32 { + // TODO(b/127845868): Define a new rpc to query the socket state. + return 0 +} + type socketProvider struct { family int } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 9393acd28..a99423365 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -116,6 +116,10 @@ type Socket interface { // SendTimeout gets the current timeout (in ns) for send operations. Zero // means no timeout, and negative means DONTWAIT. SendTimeout() int64 + + // State returns the current state of the socket, as represented by Linux in + // procfs. The returned state value is protocol-specific. + State() uint32 } // Provider is the interface implemented by providers of sockets for specific diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 5a2de0c4c..52f324eed 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -28,6 +28,7 @@ go_library( importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ + "//pkg/abi/linux", "//pkg/ilist", "//pkg/refs", "//pkg/syserr", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 18e492862..9c8ec0365 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -17,6 +17,7 @@ package transport import ( "sync" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -458,3 +459,11 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask return ready } + +// State implements socket.Socket.State. +func (e *connectionedEndpoint) State() uint32 { + if e.Connected() { + return linux.SS_CONNECTED + } + return linux.SS_UNCONNECTED +} diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 43ff875e4..c034cf984 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -15,6 +15,7 @@ package transport import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/waiter" @@ -194,3 +195,18 @@ func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMa return ready } + +// State implements socket.Socket.State. +func (e *connectionlessEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + + switch { + case e.isBound(): + return linux.SS_UNCONNECTED + case e.Connected(): + return linux.SS_CONNECTING + default: + return linux.SS_DISCONNECTING + } +} diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 37d82bb6b..5fc09af55 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -191,6 +191,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + + // State returns the current state of the socket, as represented by Linux in + // procfs. + State() uint32 } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 388cc0d8b..375542350 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -596,6 +596,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } } +// State implements socket.Socket.State. +func (s *SocketOperations) State() uint32 { + return s.ep.State() +} + // provider is a unix domain socket provider. type provider struct{} diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 8d74f1543..e8a9392b5 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -188,6 +188,10 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } +func (f *fakeTransportEndpoint) State() uint32 { + return 0 +} + type fakeTransportGoodOption bool type fakeTransportBadOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index f9886c6e4..85ef014d0 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -377,6 +377,10 @@ type Endpoint interface { // GetSockOpt gets a socket option. opt should be a pointer to one of the // *Option types. GetSockOpt(opt interface{}) *Error + + // State returns a socket's lifecycle state. The returned value is + // protocol-specific and is primarily used for diagnostics. + State() uint32 } // WriteOptions contains options for Endpoint.Write. diff --git a/pkg/tcpip/transport/icmp/BUILD b/pkg/tcpip/transport/icmp/BUILD index 9aa6f3978..84a2b53b7 100644 --- a/pkg/tcpip/transport/icmp/BUILD +++ b/pkg/tcpip/transport/icmp/BUILD @@ -33,6 +33,7 @@ go_library( "//pkg/tcpip/header", "//pkg/tcpip/stack", "//pkg/tcpip/transport/raw", + "//pkg/tcpip/transport/tcp", "//pkg/waiter", ], ) diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index e2b90ef10..b8005093a 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -708,3 +708,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements tcpip.Endpoint.State. The ICMP endpoint currently doesn't +// expose internal socket state. +func (e *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index 1daf5823f..e4ff50c91 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -519,3 +519,8 @@ func (ep *endpoint) HandlePacket(route *stack.Route, netHeader buffer.View, vv b ep.waiterQueue.Notify(waiter.EventIn) } } + +// State implements socket.Socket.State. +func (ep *endpoint) State() uint32 { + return 0 +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 31e365ae5..a32e20b06 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -226,7 +226,6 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i } n.isRegistered = true - n.state = stateConnecting // Create sender and receiver. // @@ -258,8 +257,9 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head ep.Close() return nil, err } - - ep.state = stateConnected + ep.mu.Lock() + ep.state = StateEstablished + ep.mu.Unlock() // Update the receive window scaling. We can't do it before the // handshake because it's possible that the peer doesn't support window @@ -276,7 +276,7 @@ func (e *endpoint) deliverAccepted(n *endpoint) { e.mu.RLock() state := e.state e.mu.RUnlock() - if state == stateListen { + if state == StateListen { e.acceptedChan <- n e.waiterQueue.Notify(waiter.EventIn) } else { @@ -406,7 +406,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.tsOffset = 0 // Switch state to connected. - n.state = stateConnected + n.state = StateEstablished // Do the delivery in a separate goroutine so // that we don't block the listen loop in case @@ -429,7 +429,7 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { // handleSynSegment() from attempting to queue new connections // to the endpoint. e.mu.Lock() - e.state = stateClosed + e.state = StateClose // Do cleanup if needed. e.completeWorkerLocked() diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 371d2ed29..0ad7bfb38 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -151,6 +151,9 @@ func (h *handshake) resetToSynRcvd(iss seqnum.Value, irs seqnum.Value, opts *hea h.mss = opts.MSS h.sndWndScale = opts.WS h.listenEP = listenEP + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() } // checkAck checks if the ACK number, if present, of a segment received during @@ -219,6 +222,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { // but resend our own SYN and wait for it to be acknowledged in the // SYN-RCVD state. h.state = handshakeSynRcvd + h.ep.mu.Lock() + h.ep.state = StateSynRecv + h.ep.mu.Unlock() synOpts := header.TCPSynOptions{ WS: h.rcvWndScale, TS: rcvSynOpts.TS, @@ -668,7 +674,7 @@ func (e *endpoint) makeOptions(sackBlocks []header.SACKBlock) []byte { // sendRaw sends a TCP segment to the endpoint's peer. func (e *endpoint) sendRaw(data buffer.VectorisedView, flags byte, seq, ack seqnum.Value, rcvWnd seqnum.Size) *tcpip.Error { var sackBlocks []header.SACKBlock - if e.state == stateConnected && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { + if e.state == StateEstablished && e.rcv.pendingBufSize > 0 && (flags&header.TCPFlagAck != 0) { sackBlocks = e.sack.Blocks[:e.sack.NumBlocks] } options := e.makeOptions(sackBlocks) @@ -719,8 +725,7 @@ func (e *endpoint) handleClose() *tcpip.Error { // protocol goroutine. func (e *endpoint) resetConnectionLocked(err *tcpip.Error) { e.sendRaw(buffer.VectorisedView{}, header.TCPFlagAck|header.TCPFlagRst, e.snd.sndUna, e.rcv.rcvNxt, 0) - - e.state = stateError + e.state = StateError e.hardError = err } @@ -876,14 +881,19 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // handshake, and then inform potential waiters about its // completion. h := newHandshake(e, seqnum.Size(e.receiveBufferAvailable())) + e.mu.Lock() + h.ep.state = StateSynSent + e.mu.Unlock() + if err := h.execute(); err != nil { e.lastErrorMu.Lock() e.lastError = err e.lastErrorMu.Unlock() e.mu.Lock() - e.state = stateError + e.state = StateError e.hardError = err + // Lock released below. epilogue() @@ -905,7 +915,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Tell waiters that the endpoint is connected and writable. e.mu.Lock() - e.state = stateConnected + e.state = StateEstablished drained := e.drainDone != nil e.mu.Unlock() if drained { @@ -1005,7 +1015,7 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { return err } } - if e.state != stateError { + if e.state != StateError { close(e.drainDone) <-e.undrain } @@ -1061,8 +1071,8 @@ func (e *endpoint) protocolMainLoop(handshake bool) *tcpip.Error { // Mark endpoint as closed. e.mu.Lock() - if e.state != stateError { - e.state = stateClosed + if e.state != StateError { + e.state = StateClose } // Lock released below. epilogue() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index fd697402e..23422ca5e 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -32,18 +32,81 @@ import ( "gvisor.googlesource.com/gvisor/pkg/waiter" ) -type endpointState int +// EndpointState represents the state of a TCP endpoint. +type EndpointState uint32 +// Endpoint states. Note that are represented in a netstack-specific manner and +// may not be meaningful externally. Specifically, they need to be translated to +// Linux's representation for these states if presented to userspace. const ( - stateInitial endpointState = iota - stateBound - stateListen - stateConnecting - stateConnected - stateClosed - stateError + // Endpoint states internal to netstack. These map to the TCP state CLOSED. + StateInitial EndpointState = iota + StateBound + StateConnecting // Connect() called, but the initial SYN hasn't been sent. + StateError + + // TCP protocol states. + StateEstablished + StateSynSent + StateSynRecv + StateFinWait1 + StateFinWait2 + StateTimeWait + StateClose + StateCloseWait + StateLastAck + StateListen + StateClosing ) +// connected is the set of states where an endpoint is connected to a peer. +func (s EndpointState) connected() bool { + switch s { + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: + return true + default: + return false + } +} + +// String implements fmt.Stringer.String. +func (s EndpointState) String() string { + switch s { + case StateInitial: + return "INITIAL" + case StateBound: + return "BOUND" + case StateConnecting: + return "CONNECTING" + case StateError: + return "ERROR" + case StateEstablished: + return "ESTABLISHED" + case StateSynSent: + return "SYN-SENT" + case StateSynRecv: + return "SYN-RCVD" + case StateFinWait1: + return "FIN-WAIT1" + case StateFinWait2: + return "FIN-WAIT2" + case StateTimeWait: + return "TIME-WAIT" + case StateClose: + return "CLOSED" + case StateCloseWait: + return "CLOSE-WAIT" + case StateLastAck: + return "LAST-ACK" + case StateListen: + return "LISTEN" + case StateClosing: + return "CLOSING" + default: + panic("unreachable") + } +} + // Reasons for notifying the protocol goroutine. const ( notifyNonZeroReceiveWindow = 1 << iota @@ -108,10 +171,14 @@ type endpoint struct { rcvBufUsed int // The following fields are protected by the mutex. - mu sync.RWMutex `state:"nosave"` - id stack.TransportEndpointID - state endpointState `state:".(endpointState)"` - isPortReserved bool `state:"manual"` + mu sync.RWMutex `state:"nosave"` + id stack.TransportEndpointID + + // state endpointState `state:".(endpointState)"` + // pState ProtocolState + state EndpointState `state:".(EndpointState)"` + + isPortReserved bool `state:"manual"` isRegistered bool boundNICID tcpip.NICID `state:"manual"` route stack.Route `state:"manual"` @@ -304,6 +371,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite stack: stack, netProto: netProto, waiterQueue: waiterQueue, + state: StateInitial, rcvBufSize: DefaultBufferSize, sndBufSize: DefaultBufferSize, sndMTU: int(math.MaxInt32), @@ -351,14 +419,14 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { defer e.mu.RUnlock() switch e.state { - case stateInitial, stateBound, stateConnecting: + case StateInitial, StateBound, StateConnecting, StateSynSent, StateSynRecv: // Ready for nothing. - case stateClosed, stateError: + case StateClose, StateError: // Ready for anything. result = mask - case stateListen: + case StateListen: // Check if there's anything in the accepted channel. if (mask & waiter.EventIn) != 0 { if len(e.acceptedChan) > 0 { @@ -366,7 +434,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { } } - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: // Determine if the endpoint is writable if requested. if (mask & waiter.EventOut) != 0 { e.sndBufMu.Lock() @@ -427,7 +495,7 @@ func (e *endpoint) Close() { // are immediately available for reuse after Close() is called. If also // registered, we unregister as well otherwise the next user would fail // in Listen() when trying to register. - if e.state == stateListen && e.isPortReserved { + if e.state == StateListen && e.isPortReserved { if e.isRegistered { e.stack.UnregisterTransportEndpoint(e.boundNICID, e.effectiveNetProtos, ProtocolNumber, e.id, e) e.isRegistered = false @@ -487,15 +555,15 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, e.mu.RLock() // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. Also note that a RST being received - // would cause the state to become stateError so we should allow the + // would cause the state to become StateError so we should allow the // reads to proceed before returning a ECONNRESET. e.rcvListMu.Lock() bufUsed := e.rcvBufUsed - if s := e.state; s != stateConnected && s != stateClosed && bufUsed == 0 { + if s := e.state; !s.connected() && s != StateClose && bufUsed == 0 { e.rcvListMu.Unlock() he := e.hardError e.mu.RUnlock() - if s == stateError { + if s == StateError { return buffer.View{}, tcpip.ControlMessages{}, he } return buffer.View{}, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -511,7 +579,7 @@ func (e *endpoint) Read(*tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, func (e *endpoint) readLocked() (buffer.View, *tcpip.Error) { if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return buffer.View{}, tcpip.ErrClosedForReceive } return buffer.View{}, tcpip.ErrWouldBlock @@ -547,9 +615,9 @@ func (e *endpoint) Write(p tcpip.Payload, opts tcpip.WriteOptions) (uintptr, <-c defer e.mu.RUnlock() // The endpoint cannot be written to if it's not connected. - if e.state != stateConnected { + if !e.state.connected() { switch e.state { - case stateError: + case StateError: return 0, nil, e.hardError default: return 0, nil, tcpip.ErrClosedForSend @@ -612,8 +680,8 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er // The endpoint can be read if it's connected, or if it's already closed // but has some pending unread data. - if s := e.state; s != stateConnected && s != stateClosed { - if s == stateError { + if s := e.state; !s.connected() && s != StateClose { + if s == StateError { return 0, tcpip.ControlMessages{}, e.hardError } return 0, tcpip.ControlMessages{}, tcpip.ErrInvalidEndpointState @@ -623,7 +691,7 @@ func (e *endpoint) Peek(vec [][]byte) (uintptr, tcpip.ControlMessages, *tcpip.Er defer e.rcvListMu.Unlock() if e.rcvBufUsed == 0 { - if e.rcvClosed || e.state != stateConnected { + if e.rcvClosed || !e.state.connected() { return 0, tcpip.ControlMessages{}, tcpip.ErrClosedForReceive } return 0, tcpip.ControlMessages{}, tcpip.ErrWouldBlock @@ -789,7 +857,7 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { defer e.mu.Unlock() // We only allow this to be set when we're in the initial state. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrInvalidEndpointState } @@ -841,7 +909,7 @@ func (e *endpoint) readyReceiveSize() (int, *tcpip.Error) { defer e.mu.RUnlock() // The endpoint cannot be in listen state. - if e.state == stateListen { + if e.state == StateListen { return 0, tcpip.ErrInvalidEndpointState } @@ -1057,7 +1125,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid := addr.NIC switch e.state { - case stateBound: + case StateBound: // If we're already bound to a NIC but the caller is requesting // that we use a different one now, we cannot proceed. if e.boundNICID == 0 { @@ -1070,16 +1138,16 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er nicid = e.boundNICID - case stateInitial: - // Nothing to do. We'll eventually fill-in the gaps in the ID - // (if any) when we find a route. + case StateInitial: + // Nothing to do. We'll eventually fill-in the gaps in the ID (if any) + // when we find a route. - case stateConnecting: - // A connection request has already been issued but hasn't - // completed yet. + case StateConnecting, StateSynSent, StateSynRecv: + // A connection request has already been issued but hasn't completed + // yet. return tcpip.ErrAlreadyConnecting - case stateConnected: + case StateEstablished: // The endpoint is already connected. If caller hasn't been notified yet, return success. if !e.isConnectNotified { e.isConnectNotified = true @@ -1088,7 +1156,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er // Otherwise return that it's already connected. return tcpip.ErrAlreadyConnected - case stateError: + case StateError: return e.hardError default: @@ -1154,7 +1222,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.isRegistered = true - e.state = stateConnecting + e.state = StateConnecting e.route = r.Clone() e.boundNICID = nicid e.effectiveNetProtos = netProtos @@ -1175,7 +1243,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) (er } e.segmentQueue.mu.Unlock() e.snd.updateMaxPayloadSize(int(e.route.MTU()), 0) - e.state = stateConnected + e.state = StateEstablished } if run { @@ -1199,8 +1267,8 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { defer e.mu.Unlock() e.shutdownFlags |= flags - switch e.state { - case stateConnected: + switch { + case e.state.connected(): // Close for read. if (e.shutdownFlags & tcpip.ShutdownRead) != 0 { // Mark read side as closed. @@ -1241,7 +1309,7 @@ func (e *endpoint) Shutdown(flags tcpip.ShutdownFlags) *tcpip.Error { e.sndCloseWaker.Assert() } - case stateListen: + case e.state == StateListen: // Tell protocolListenLoop to stop. if flags&tcpip.ShutdownRead != 0 { e.notifyProtocolGoroutine(notifyClose) @@ -1269,7 +1337,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { // When the endpoint shuts down, it sets workerCleanup to true, and from // that point onward, acceptedChan is the responsibility of the cleanup() // method (and should not be touched anywhere else, including here). - if e.state == stateListen && !e.workerCleanup { + if e.state == StateListen && !e.workerCleanup { // Adjust the size of the channel iff we can fix existing // pending connections into the new one. if len(e.acceptedChan) > backlog { @@ -1288,7 +1356,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } // Endpoint must be bound before it can transition to listen mode. - if e.state != stateBound { + if e.state != StateBound { return tcpip.ErrInvalidEndpointState } @@ -1298,7 +1366,7 @@ func (e *endpoint) Listen(backlog int) (err *tcpip.Error) { } e.isRegistered = true - e.state = stateListen + e.state = StateListen if e.acceptedChan == nil { e.acceptedChan = make(chan *endpoint, backlog) } @@ -1325,7 +1393,7 @@ func (e *endpoint) Accept() (tcpip.Endpoint, *waiter.Queue, *tcpip.Error) { defer e.mu.RUnlock() // Endpoint must be in listen state before it can accept connections. - if e.state != stateListen { + if e.state != StateListen { return nil, nil, tcpip.ErrInvalidEndpointState } @@ -1353,7 +1421,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { // Don't allow binding once endpoint is not in the initial state // anymore. This is because once the endpoint goes into a connected or // listen state, it is already bound. - if e.state != stateInitial { + if e.state != StateInitial { return tcpip.ErrAlreadyBound } @@ -1408,7 +1476,7 @@ func (e *endpoint) Bind(addr tcpip.FullAddress) (err *tcpip.Error) { } // Mark endpoint as bound. - e.state = stateBound + e.state = StateBound return nil } @@ -1430,7 +1498,7 @@ func (e *endpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { e.mu.RLock() defer e.mu.RUnlock() - if e.state != stateConnected { + if !e.state.connected() { return tcpip.FullAddress{}, tcpip.ErrNotConnected } @@ -1739,3 +1807,11 @@ func (e *endpoint) initGSO() { gso.MaxSize = e.route.GSOMaxSize() e.gso = gso } + +// State implements tcpip.Endpoint.State. It exports the endpoint's protocol +// state for diagnostics. +func (e *endpoint) State() uint32 { + e.mu.Lock() + defer e.mu.Unlock() + return uint32(e.state) +} diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index e8aed2875..5f30c2374 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -49,8 +49,8 @@ func (e *endpoint) beforeSave() { defer e.mu.Unlock() switch e.state { - case stateInitial, stateBound: - case stateConnected: + case StateInitial, StateBound: + case StateEstablished, StateSynSent, StateSynRecv, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.id.LocalAddress, e.id.LocalPort, e.id.RemoteAddress, e.id.RemotePort)}) @@ -66,17 +66,17 @@ func (e *endpoint) beforeSave() { break } fallthrough - case stateListen, stateConnecting: + case StateListen, StateConnecting: e.drainSegmentLocked() - if e.state != stateClosed && e.state != stateError { + if e.state != StateClose && e.state != StateError { if !e.workerRunning { panic("endpoint has no worker running in listen, connecting, or connected state") } break } fallthrough - case stateError, stateClosed: - for e.state == stateError && e.workerRunning { + case StateError, StateClose: + for e.state == StateError && e.workerRunning { e.mu.Unlock() time.Sleep(100 * time.Millisecond) e.mu.Lock() @@ -92,7 +92,7 @@ func (e *endpoint) beforeSave() { panic("endpoint still has waiters upon save") } - if e.state != stateClosed && !((e.state == stateBound || e.state == stateListen) == e.isPortReserved) { + if e.state != StateClose && !((e.state == StateBound || e.state == StateListen) == e.isPortReserved) { panic("endpoints which are not in the closed state must have a reserved port IFF they are in bound or listen state") } } @@ -132,7 +132,7 @@ func (e *endpoint) loadAcceptedChan(acceptedEndpoints []*endpoint) { } // saveState is invoked by stateify. -func (e *endpoint) saveState() endpointState { +func (e *endpoint) saveState() EndpointState { return e.state } @@ -146,15 +146,15 @@ var connectingLoading sync.WaitGroup // Bound endpoint loading happens last. // loadState is invoked by stateify. -func (e *endpoint) loadState(state endpointState) { +func (e *endpoint) loadState(state EndpointState) { // This is to ensure that the loading wait groups include all applicable // endpoints before any asynchronous calls to the Wait() methods. switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: connectedLoading.Add(1) - case stateListen: + case StateListen: listenLoading.Add(1) - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: connectingLoading.Add(1) } e.state = state @@ -168,7 +168,7 @@ func (e *endpoint) afterLoad() { state := e.state switch state { - case stateInitial, stateBound, stateListen, stateConnecting, stateConnected: + case StateInitial, StateBound, StateListen, StateConnecting, StateEstablished: var ss SendBufferSizeOption if err := e.stack.TransportProtocolOption(ProtocolNumber, &ss); err == nil { if e.sndBufSize < ss.Min || e.sndBufSize > ss.Max { @@ -181,7 +181,7 @@ func (e *endpoint) afterLoad() { } bind := func() { - e.state = stateInitial + e.state = StateInitial if len(e.bindAddress) == 0 { e.bindAddress = e.id.LocalAddress } @@ -191,7 +191,7 @@ func (e *endpoint) afterLoad() { } switch state { - case stateConnected: + case StateEstablished, StateFinWait1, StateFinWait2, StateTimeWait, StateCloseWait, StateLastAck, StateClosing: bind() if len(e.connectingAddress) == 0 { // This endpoint is accepted by netstack but not yet by @@ -211,7 +211,7 @@ func (e *endpoint) afterLoad() { panic("endpoint connecting failed: " + err.String()) } connectedLoading.Done() - case stateListen: + case StateListen: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -223,7 +223,7 @@ func (e *endpoint) afterLoad() { listenLoading.Done() tcpip.AsyncLoading.Done() }() - case stateConnecting: + case StateConnecting, StateSynSent, StateSynRecv: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -235,7 +235,7 @@ func (e *endpoint) afterLoad() { connectingLoading.Done() tcpip.AsyncLoading.Done() }() - case stateBound: + case StateBound: tcpip.AsyncLoading.Add(1) go func() { connectedLoading.Wait() @@ -244,7 +244,7 @@ func (e *endpoint) afterLoad() { bind() tcpip.AsyncLoading.Done() }() - case stateClosed: + case StateClose: if e.isPortReserved { tcpip.AsyncLoading.Add(1) go func() { @@ -252,12 +252,12 @@ func (e *endpoint) afterLoad() { listenLoading.Wait() connectingLoading.Wait() bind() - e.state = stateClosed + e.state = StateClose tcpip.AsyncLoading.Done() }() } fallthrough - case stateError: + case StateError: tcpip.DeleteDanglingEndpoint(e) } } diff --git a/pkg/tcpip/transport/tcp/rcv.go b/pkg/tcpip/transport/tcp/rcv.go index b08a0e356..f02fa6105 100644 --- a/pkg/tcpip/transport/tcp/rcv.go +++ b/pkg/tcpip/transport/tcp/rcv.go @@ -134,6 +134,7 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum // sequence numbers that have been consumed. TrimSACKBlockList(&r.ep.sack, r.rcvNxt) + // Handle FIN or FIN-ACK. if s.flagIsSet(header.TCPFlagFin) { r.rcvNxt++ @@ -144,6 +145,25 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.closed = true r.ep.readyToRead(nil) + // We just received a FIN, our next state depends on whether we sent a + // FIN already or not. + r.ep.mu.Lock() + switch r.ep.state { + case StateEstablished: + r.ep.state = StateCloseWait + case StateFinWait1: + if s.flagIsSet(header.TCPFlagAck) { + // FIN-ACK, transition to TIME-WAIT. + r.ep.state = StateTimeWait + } else { + // Simultaneous close, expecting a final ACK. + r.ep.state = StateClosing + } + case StateFinWait2: + r.ep.state = StateTimeWait + } + r.ep.mu.Unlock() + // Flush out any pending segments, except the very first one if // it happens to be the one we're handling now because the // caller is using it. @@ -156,6 +176,23 @@ func (r *receiver) consumeSegment(s *segment, segSeq seqnum.Value, segLen seqnum r.pendingRcvdSegments[i].decRef() } r.pendingRcvdSegments = r.pendingRcvdSegments[:first] + + return true + } + + // Handle ACK (not FIN-ACK, which we handled above) during one of the + // shutdown states. + if s.flagIsSet(header.TCPFlagAck) { + r.ep.mu.Lock() + switch r.ep.state { + case StateFinWait1: + r.ep.state = StateFinWait2 + case StateClosing: + r.ep.state = StateTimeWait + case StateLastAck: + r.ep.state = StateClose + } + r.ep.mu.Unlock() } return true diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 3464e4be7..b236d7af2 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -632,6 +632,10 @@ func (s *sender) maybeSendSegment(seg *segment, limit int, end seqnum.Value) (se } seg.flags = header.TCPFlagAck | header.TCPFlagFin segEnd = seg.sequenceNumber.Add(1) + // Transition to FIN-WAIT1 state since we're initiating an active close. + s.ep.mu.Lock() + s.ep.state = StateFinWait1 + s.ep.mu.Unlock() } else { // We're sending a non-FIN segment. if seg.flags&header.TCPFlagFin != 0 { diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index b8f0ccaf1..56b490aaa 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -168,8 +168,8 @@ func TestTCPResetsSentIncrement(t *testing.T) { // Receive the SYN-ACK reply. b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) ackHeaders := &context.Headers{ SrcPort: context.TestPort, @@ -269,8 +269,8 @@ func TestConnectResetAfterClose(t *testing.T) { time.Sleep(3 * time.Second) for { b := c.GetPacket() - tcp := header.TCP(header.IPv4(b).Payload()) - if tcp.Flags() == header.TCPFlagAck|header.TCPFlagFin { + tcpHdr := header.TCP(header.IPv4(b).Payload()) + if tcpHdr.Flags() == header.TCPFlagAck|header.TCPFlagFin { // This is a retransmit of the FIN, ignore it. continue } @@ -553,9 +553,13 @@ func TestRstOnCloseWithUnreadData(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } - // This final should be ignored because an ACK on a reset doesn't - // mean anything. + // This final ACK should be ignored because an ACK on a reset doesn't mean + // anything. c.SendPacket(nil, &context.Headers{ SrcPort: context.TestPort, DstPort: c.Port, @@ -618,6 +622,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { checker.SeqNum(uint32(c.IRS)+1), )) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Cause a RST to be generated by closing the read end now since we have // unread data. c.EP.Shutdown(tcpip.ShutdownRead) @@ -630,6 +638,10 @@ func TestRstOnCloseWithUnreadDataFinConvertRst(t *testing.T) { // We shouldn't consume a sequence number on RST. checker.SeqNum(uint32(c.IRS)+1), )) + // The RST puts the endpoint into an error state. + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateError; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } // The ACK to the FIN should now be rejected since the connection has been // closed by a RST. @@ -1510,8 +1522,8 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { for bytesReceived != dataLen { b := c.GetPacket() numPackets++ - tcp := header.TCP(header.IPv4(b).Payload()) - payloadLen := len(tcp.Payload()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + payloadLen := len(tcpHdr.Payload()) checker.IPv4(t, b, checker.TCP( checker.DstPort(context.TestPort), @@ -1522,7 +1534,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { ) pdata := data[bytesReceived : bytesReceived+payloadLen] - if p := tcp.Payload(); !bytes.Equal(pdata, p) { + if p := tcpHdr.Payload(); !bytes.Equal(pdata, p) { t.Fatalf("got data = %v, want = %v", p, pdata) } bytesReceived += payloadLen @@ -1530,7 +1542,7 @@ func testBrokenUpWrite(t *testing.T, c *context.Context, maxPayload int) { if c.TimeStampEnabled { // If timestamp option is enabled, echo back the timestamp and increment // the TSEcr value included in the packet and send that back as the TSVal. - parsedOpts := tcp.ParsedOptions() + parsedOpts := tcpHdr.ParsedOptions() tsOpt := [12]byte{header.TCPOptionNOP, header.TCPOptionNOP} header.EncodeTSOption(parsedOpts.TSEcr+1, parsedOpts.TSVal, tsOpt[2:]) options = tsOpt[:] @@ -1757,8 +1769,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) // Wait for retransmit. time.Sleep(1 * time.Second) @@ -1766,8 +1778,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { checker.TCP( checker.DstPort(context.TestPort), checker.TCPFlags(header.TCPFlagSyn), - checker.SrcPort(tcp.SourcePort()), - checker.SeqNum(tcp.SequenceNumber()), + checker.SrcPort(tcpHdr.SourcePort()), + checker.SeqNum(tcpHdr.SequenceNumber()), checker.TCPSynOptions(header.TCPSynOptions{MSS: mss, WS: wndScale}), ), ) @@ -1775,8 +1787,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Send SYN-ACK. iss := seqnum.Value(789) c.SendPacket(nil, &context.Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -2523,8 +2535,8 @@ func TestReceivedSegmentQueuing(t *testing.T) { checker.TCPFlags(header.TCPFlagAck), ), ) - tcp := header.TCP(header.IPv4(b).Payload()) - ack := seqnum.Value(tcp.AckNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + ack := seqnum.Value(tcpHdr.AckNumber()) if ack == last { break } @@ -2568,6 +2580,10 @@ func TestReadAfterClosedState(t *testing.T) { ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateFinWait1; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Send some data and acknowledge the FIN. data := []byte{1, 2, 3} c.SendPacket(data, &context.Headers{ @@ -2589,9 +2605,15 @@ func TestReadAfterClosedState(t *testing.T) { ), ) - // Give the stack the chance to transition to closed state. + // Give the stack the chance to transition to closed state. Note that since + // both the sender and receiver are now closed, we effectively skip the + // TIME-WAIT state. time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Wait for receive to be notified. select { case <-ch: @@ -3680,9 +3702,15 @@ func TestPassiveConnectionAttemptIncrement(t *testing.T) { if err := ep.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := c.EP.Listen(1); err != nil { t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } stats := c.Stack().Stats() want := stats.TCP.PassiveConnectionOpenings.Value() + 1 @@ -3826,3 +3854,68 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } } + +func TestEndpointBindListenAcceptState(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + wq := &waiter.Queue{} + ep, err := c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + if err := ep.Bind(tcpip.FullAddress{Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + if err := ep.Listen(10); err != nil { + t.Fatalf("Listen failed: %v", err) + } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + c.PassiveConnectWithOptions(100, 5, header.TCPSynOptions{MSS: defaultIPv4MSS}) + + // Try to accept the connection. + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + defer wq.EventUnregister(&we) + + aep, _, err := ep.Accept() + if err == tcpip.ErrWouldBlock { + // Wait for connection to be established. + select { + case <-ch: + aep, _, err = ep.Accept() + if err != nil { + t.Fatalf("Accept failed: %v", err) + } + + case <-time.After(1 * time.Second): + t.Fatalf("Timed out waiting for accept") + } + } + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Listening endpoint remains in listen state. + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + + ep.Close() + // Give worker goroutines time to receive the close notification. + time.Sleep(1 * time.Second) + if got, want := tcp.EndpointState(ep.State()), tcp.StateClose; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + // Accepted endpoint remains open when the listen endpoint is closed. + if got, want := tcp.EndpointState(aep.State()), tcp.StateEstablished; got != want { + t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + +} diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 6e12413c6..69a43b6f4 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -532,6 +532,9 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. if err != nil { c.t.Fatalf("NewEndpoint failed: %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if epRcvBuf != nil { if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { @@ -557,13 +560,16 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. checker.TCPFlags(header.TCPFlagSyn), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - tcp := header.TCP(header.IPv4(b).Payload()) - c.IRS = seqnum.Value(tcp.SequenceNumber()) + tcpHdr := header.TCP(header.IPv4(b).Payload()) + c.IRS = seqnum.Value(tcpHdr.SequenceNumber()) c.SendPacket(nil, &Headers{ - SrcPort: tcp.DestinationPort(), - DstPort: tcp.SourcePort(), + SrcPort: tcpHdr.DestinationPort(), + DstPort: tcpHdr.SourcePort(), Flags: header.TCPFlagSyn | header.TCPFlagAck, SeqNum: iss, AckNum: c.IRS.Add(1), @@ -591,8 +597,11 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } - c.Port = tcp.SourcePort() + c.Port = tcpHdr.SourcePort() } // RawEndpoint is just a small wrapper around a TCP endpoint's state to make @@ -690,6 +699,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * if err != nil { c.t.Fatalf("c.s.NewEndpoint(tcp, ipv4...) = %v", err) } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) @@ -719,6 +731,10 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * }), ), ) + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateSynSent; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } + tcpSeg := header.TCP(header.IPv4(b).Payload()) synOptions := header.ParseSynOptions(tcpSeg.Options(), false) @@ -782,6 +798,9 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * case <-time.After(1 * time.Second): c.t.Fatalf("Timed out waiting for connection") } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Fatalf("Unexpected endpoint state: want %v, got %v", want, got) + } // Store the source port in use by the endpoint. c.Port = tcpSeg.SourcePort() @@ -821,10 +840,16 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption if err := ep.Bind(tcpip.FullAddress{Port: StackPort}); err != nil { c.t.Fatalf("Bind failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateBound; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } if err := ep.Listen(10); err != nil { c.t.Fatalf("Listen failed: %v", err) } + if got, want := tcp.EndpointState(ep.State()), tcp.StateListen; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } rep := c.PassiveConnectWithOptions(100, wndScale, synOptions) @@ -847,6 +872,10 @@ func (c *Context) AcceptWithOptions(wndScale int, synOptions header.TCPSynOption c.t.Fatalf("Timed out waiting for accept") } } + if got, want := tcp.EndpointState(c.EP.State()), tcp.StateEstablished; got != want { + c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) + } + return rep } diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 3d52a4f31..fa7278286 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -1000,3 +1000,9 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, vv // HandleControlPacket implements stack.TransportEndpoint.HandleControlPacket. func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, vv buffer.VectorisedView) { } + +// State implements socket.Socket.State. +func (e *endpoint) State() uint32 { + // TODO(b/112063468): Translate internal state to values returned by Linux. + return 0 +} diff --git a/test/syscalls/linux/proc_net_unix.cc b/test/syscalls/linux/proc_net_unix.cc index 6d745f728..82d325c17 100644 --- a/test/syscalls/linux/proc_net_unix.cc +++ b/test/syscalls/linux/proc_net_unix.cc @@ -34,6 +34,16 @@ using absl::StrFormat; constexpr char kProcNetUnixHeader[] = "Num RefCount Protocol Flags Type St Inode Path"; +// Possible values of the "st" field in a /proc/net/unix entry. Source: Linux +// kernel, include/uapi/linux/net.h. +enum { + SS_FREE = 0, // Not allocated + SS_UNCONNECTED, // Unconnected to any socket + SS_CONNECTING, // In process of connecting + SS_CONNECTED, // Connected to socket + SS_DISCONNECTING // In process of disconnecting +}; + // UnixEntry represents a single entry from /proc/net/unix. struct UnixEntry { uintptr_t addr; @@ -71,7 +81,12 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { bool skipped_header = false; std::vector<UnixEntry> entries; std::vector<std::string> lines = absl::StrSplit(content, absl::ByAnyChar("\n")); + std::cerr << "<contents of /proc/net/unix>" << std::endl; for (std::string line : lines) { + // Emit the proc entry to the test output to provide context for the test + // results. + std::cerr << line << std::endl; + if (!skipped_header) { EXPECT_EQ(line, kProcNetUnixHeader); skipped_header = true; @@ -139,6 +154,7 @@ PosixErrorOr<std::vector<UnixEntry>> ProcNetUnixEntries() { entries.push_back(entry); } + std::cerr << "<end of /proc/net/unix>" << std::endl; return entries; } @@ -241,6 +257,168 @@ TEST(ProcNetUnix, SocketPair) { EXPECT_EQ(entries.size(), 2); } +TEST(ProcNetUnix, StreamSocketStateUnconnectedOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, StreamSocketStateStateUnconnectedOnListen) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); + + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); + // The bind and listen entries should refer to the same socket. + EXPECT_EQ(listen_entry.inode, bind_entry.inode); +} + +TEST(ProcNetUnix, StreamSocketStateStateConnectedOnAccept) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_STREAM).Create()); + const std::string address = ExtractPath(sockets->first_addr()); + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + ASSERT_THAT(listen(sockets->first_fd(), 5), SyscallSucceeds()); + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry listen_entry; + ASSERT_TRUE( + FindByPath(entries, &listen_entry, ExtractPath(sockets->first_addr()))); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + int clientfd; + ASSERT_THAT(clientfd = accept(sockets->first_fd(), nullptr, nullptr), + SyscallSucceeds()); + + // Find the entry for the accepted socket. UDS proc entries don't have a + // remote address, so we distinguish the accepted socket from the listen + // socket by checking for a different inode. + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + UnixEntry accept_entry; + ASSERT_TRUE(FindBy( + entries, &accept_entry, [address, listen_entry](const UnixEntry& e) { + return e.path == address && e.inode != listen_entry.inode; + })); + EXPECT_EQ(accept_entry.state, SS_CONNECTED); + // Listen entry should still be in SS_UNCONNECTED state. + ASSERT_TRUE(FindBy(entries, &listen_entry, + [&sockets, listen_entry](const UnixEntry& e) { + return e.path == ExtractPath(sockets->first_addr()) && + e.inode == listen_entry.inode; + })); + EXPECT_EQ(listen_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateDisconnectingOnBind) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + EXPECT_EQ(bind_entry.state, SS_UNCONNECTED); +} + +TEST(ProcNetUnix, DgramSocketStateConnectingOnConnect) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE( + AbstractUnboundUnixDomainSocketPair(SOCK_DGRAM).Create()); + + std::vector<UnixEntry> entries = + ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // On gVisor, the only two UDS on the system are the ones we just created and + // we rely on this to locate the test socket entries in the remainder of the + // test. On a generic Linux system, we have no easy way to locate the + // corresponding entries, as they don't have an address yet. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + for (auto e : entries) { + ASSERT_EQ(e.state, SS_DISCONNECTING); + } + } + + ASSERT_THAT(bind(sockets->first_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + const std::string address = ExtractPath(sockets->first_addr()); + UnixEntry bind_entry; + ASSERT_TRUE(FindByPath(entries, &bind_entry, address)); + + ASSERT_THAT(connect(sockets->second_fd(), sockets->first_addr(), + sockets->first_addr_size()), + SyscallSucceeds()); + + entries = ASSERT_NO_ERRNO_AND_VALUE(ProcNetUnixEntries()); + + // Once again, we have no easy way to identify the connecting socket as it has + // no listed address. We can only identify the entry as the "non-bind socket + // entry" on gVisor, where we're guaranteed to have only the two entries we + // create during this test. + if (IsRunningOnGvisor()) { + ASSERT_EQ(entries.size(), 2); + UnixEntry connect_entry; + ASSERT_TRUE( + FindBy(entries, &connect_entry, [bind_entry](const UnixEntry& e) { + return e.inode != bind_entry.inode; + })); + EXPECT_EQ(connect_entry.state, SS_CONNECTING); + } +} + } // namespace } // namespace testing } // namespace gvisor -- cgit v1.2.3 From a00157cc0e216a9829f2659ce35c856a22aa5ba2 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood <rahat@google.com> Date: Mon, 10 Jun 2019 15:16:42 -0700 Subject: Store more information in the kernel socket table. Store enough information in the kernel socket table to distinguish between different types of sockets. Previously we were only storing the socket family, but this isn't enough to classify sockets. For example, TCPv4 and UDPv4 sockets are both AF_INET, and ICMP sockets are SOCK_DGRAM sockets with a particular protocol. Instead of creating more sub-tables, flatten the socket table and provide a filtering mechanism based on the socket entry. Also generate and store a socket entry index ("sl" in linux) which allows us to output entries in a stable order from procfs. PiperOrigin-RevId: 252495895 --- pkg/sentry/fs/host/socket.go | 4 +-- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/net.go | 14 ++++---- pkg/sentry/kernel/BUILD | 13 +++++++ pkg/sentry/kernel/kernel.go | 55 +++++++++++++--------------- pkg/sentry/socket/epsocket/epsocket.go | 13 +++++-- pkg/sentry/socket/epsocket/provider.go | 2 +- pkg/sentry/socket/hostinet/socket.go | 32 +++++++++++------ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/socket.go | 12 ++++++- pkg/sentry/socket/rpcinet/socket.go | 16 +++++++-- pkg/sentry/socket/socket.go | 9 +++-- pkg/sentry/socket/unix/unix.go | 65 ++++++++++++++++++++-------------- 13 files changed, 152 insertions(+), 86 deletions(-) (limited to 'pkg/sentry/fs/proc') diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 6423ad938..305eea718 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -164,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -196,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil + return unixsocket.New(ctx, ep, e.stype), nil } // Send implements transport.ConnectedEndpoint.Send. diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index d19c360e0..1728fe0b5 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -45,6 +45,7 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/mm", + "//pkg/sentry/socket", "//pkg/sentry/socket/rpcinet", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 3daaa962c..034950158 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -27,6 +27,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs" "gvisor.googlesource.com/gvisor/pkg/sentry/inet" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) @@ -213,17 +214,18 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n") // Entries - for _, sref := range n.k.ListSockets(linux.AF_UNIX) { - s := sref.Get() + for _, se := range n.k.ListSockets() { + s := se.Sock.Get() if s == nil { - log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref) + log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock) continue } sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(*unix.SocketOperations) - if !ok { - panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile)) + if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { + // Not a unix socket. + continue } + sops := sfile.FileOperations.(*unix.SocketOperations) addr, err := sops.Endpoint().GetLocalAddress() if err != nil { diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 99a2fd964..04e375910 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -64,6 +64,18 @@ go_template_instance( }, ) +go_template_instance( + name = "socket_list", + out = "socket_list.go", + package = "kernel", + prefix = "socket", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SocketEntry", + "Linker": "*SocketEntry", + }, +) + proto_library( name = "uncaught_signal_proto", srcs = ["uncaught_signal.proto"], @@ -104,6 +116,7 @@ go_library( "sessions.go", "signal.go", "signal_handlers.go", + "socket_list.go", "syscalls.go", "syscalls_state.go", "syslog.go", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 85d73ace2..f253a81d9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -182,9 +182,13 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // socketTable is used to track all sockets on the system. Protected by + // sockets is the list of all network sockets the system. Protected by // extMu. - socketTable map[int]map[*refs.WeakRef]struct{} + sockets socketList + + // nextSocketEntry is the next entry number to use in sockets. Protected + // by extMu. + nextSocketEntry uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -283,7 +287,6 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() - k.socketTable = make(map[int]map[*refs.WeakRef]struct{}) return nil } @@ -1137,51 +1140,43 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { }) } -// socketEntry represents a socket recorded in Kernel.socketTable. It implements +// SocketEntry represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type socketEntry struct { - k *Kernel - sock *refs.WeakRef - family int +type SocketEntry struct { + socketEntry + k *Kernel + Sock *refs.WeakRef + ID uint64 // Socket table entry number. } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *socketEntry) WeakRefGone() { +func (s *SocketEntry) WeakRefGone() { s.k.extMu.Lock() - // k.socketTable is guaranteed to point to a valid socket table for s.family - // at this point, since we made sure of the fact when we created this - // socketEntry, and we never delete socket tables. - delete(s.k.socketTable[s.family], s.sock) + s.k.sockets.Remove(s) s.k.extMu.Unlock() } // RecordSocket adds a socket to the system-wide socket table for tracking. // // Precondition: Caller must hold a reference to sock. -func (k *Kernel) RecordSocket(sock *fs.File, family int) { +func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - table, ok := k.socketTable[family] - if !ok { - table = make(map[*refs.WeakRef]struct{}) - k.socketTable[family] = table - } - se := socketEntry{k: k, family: family} - se.sock = refs.NewWeakRef(sock, &se) - table[se.sock] = struct{}{} + id := k.nextSocketEntry + k.nextSocketEntry++ + s := &SocketEntry{k: k, ID: id} + s.Sock = refs.NewWeakRef(sock, s) + k.sockets.PushBack(s) k.extMu.Unlock() } -// ListSockets returns a snapshot of all sockets of a given family. -func (k *Kernel) ListSockets(family int) []*refs.WeakRef { +// ListSockets returns a snapshot of all sockets. +func (k *Kernel) ListSockets() []*SocketEntry { k.extMu.Lock() - socks := []*refs.WeakRef{} - if table, ok := k.socketTable[family]; ok { - socks = make([]*refs.WeakRef, 0, len(table)) - for s := range table { - socks = append(socks, s) - } + var socks []*SocketEntry + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, s) } k.extMu.Unlock() return socks diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e1e29de35..f67451179 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -228,6 +228,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint skType linux.SockType + protocol int // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -252,7 +253,7 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) @@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, family: family, Endpoint: endpoint, skType: skType, + protocol: protocol, }), nil } @@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns, err := New(t, s.family, s.skType, wq, ep) + ns, err := New(t, s.family, s.skType, s.protocol, wq, ep) if err != nil { return 0, nil, 0, err } @@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(ns, s.family) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, syserr.FromError(e) } @@ -2324,3 +2326,8 @@ func (s *SocketOperations) State() uint32 { // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } + +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.skType, s.protocol +} diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index e48a106ea..516582828 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -111,7 +111,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, syserr.TranslateNetstackError(e) } - return New(t, p.family, stype, wq, ep) + return New(t, p.family, stype, protocol, wq, ep) } // Pair just returns nil sockets (not supported). diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 4517951a0..c62c8d8f1 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -56,15 +56,22 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. - fd int // must be O_NONBLOCK - queue waiter.Queue + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd int // must be O_NONBLOCK + queue waiter.Queue } var _ = socket.Socket(&socketOperations{}) -func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) { - s := &socketOperations{family: family, fd: fd} +func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { + s := &socketOperations{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -222,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0) + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) if err != nil { syscall.Close(fd) return 0, nil, 0, err @@ -233,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, } kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(f, s.family) + t.Kernel().RecordSocket(f) return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } @@ -542,6 +549,11 @@ func (s *socketOperations) State() uint32 { return uint32(info.State) } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -558,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Only accept TCP and UDP. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -581,11 +593,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto // Conservatively ignore all flags specified by the application and add // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. - fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { return nil, syserr.FromError(err) } - return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) + return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 863edc241..5dc103877 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -82,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int return nil, err } - s, err := NewSocket(t, p) + s, err := NewSocket(t, stype, p) if err != nil { return nil, err } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 16c79aa33..62659784a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -80,6 +80,10 @@ type Socket struct { // protocol is the netlink protocol implementation. protocol Protocol + // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for + // netlink sockets. + skType linux.SockType + // ep is a datagram unix endpoint used to buffer messages sent from the // kernel to userspace. RecvMsg reads messages from this endpoint. ep transport.Endpoint @@ -105,7 +109,7 @@ type Socket struct { var _ socket.Socket = (*Socket)(nil) // NewSocket creates a new Socket. -func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { +func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { // Datagram endpoint used to buffer kernel -> user messages. ep := transport.NewConnectionless() @@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { return &Socket{ ports: t.Kernel().NetlinkPorts(), protocol: protocol, + skType: skType, ep: ep, connection: connection, sendBufferSize: defaultSendBufferSize, @@ -621,3 +626,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, func (s *Socket) State() uint32 { return s.ep.State() } + +// Type implements socket.Socket.Type. +func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { + return linux.AF_NETLINK, s.skType, s.protocol.Protocol() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 2d5b5b58f..c22ff1ff0 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -53,7 +53,10 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd uint32 // must be O_NONBLOCK wq *waiter.Queue rpcConn *conn.RPCConnection @@ -86,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.S defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{ family: family, + stype: skType, + protocol: protocol, wq: &wq, fd: fd, rpcConn: stack.rpcConn, @@ -332,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, syserr.FromError(err) } - t.Kernel().RecordSocket(file, s.family) + t.Kernel().RecordSocket(file) if peerRequested { return fd, payload.Address.Address, payload.Address.Length, nil @@ -835,6 +840,11 @@ func (s *socketOperations) State() uint32 { return 0 } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -876,7 +886,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto return nil, nil } - return newSocketFile(t, s, p.family, stype, 0) + return newSocketFile(t, s, p.family, stype, protocol) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index f1021ec67..d60944b6b 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -120,6 +120,9 @@ type Socket interface { // State returns the current state of the socket, as represented by Linux in // procfs. The returned state value is protocol-specific. State() uint32 + + // Type returns the family, socket type and protocol of the socket. + Type() (family int, skType linux.SockType, protocol int) } // Provider is the interface implemented by providers of sockets for specific @@ -156,7 +159,7 @@ func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.Fi return nil, err } if s != nil { - t.Kernel().RecordSocket(s, family) + t.Kernel().RecordSocket(s) return s, nil } } @@ -179,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.F } if s1 != nil && s2 != nil { k := t.Kernel() - k.RecordSocket(s1, family) - k.RecordSocket(s2, family) + k.RecordSocket(s1) + k.RecordSocket(s2) return s1, s2, nil } } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 56ed63e21..b07e8d67b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -17,6 +17,7 @@ package unix import ( + "fmt" "strings" "syscall" @@ -55,22 +56,22 @@ type SocketOperations struct { refs.AtomicRefCount socket.SendReceiveTimeout - ep transport.Endpoint - isPacket bool + ep transport.Endpoint + stype linux.SockType } // New creates a new unix socket. -func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File { +func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) defer dirent.DecRef() - return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true}) + return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true}) } // NewWithDirent creates a new unix socket using an existing dirent. -func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File { +func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File { return fs.NewFile(ctx, d, flags, &SocketOperations{ - ep: ep, - isPacket: isPacket, + ep: ep, + stype: stype, }) } @@ -88,6 +89,18 @@ func (s *SocketOperations) Release() { s.DecRef() } +func (s *SocketOperations) isPacket() bool { + switch s.stype { + case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + return true + case linux.SOCK_STREAM: + return false + default: + // We shouldn't have allowed any other socket types during creation. + panic(fmt.Sprintf("Invalid socket type %d", s.stype)) + } +} + // Endpoint extracts the transport.Endpoint. func (s *SocketOperations) Endpoint() transport.Endpoint { return s.ep @@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns := New(t, ep, s.isPacket) + ns := New(t, ep, s.stype) defer ns.DecRef() if flags&linux.SOCK_NONBLOCK != 0 { @@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, nil, 0, syserr.FromError(e) } - t.Kernel().RecordSocket(ns, linux.AF_UNIX) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, nil } @@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 waitAll := flags&linux.MSG_WAITALL != 0 + isPacket := s.isPacket() // Calculate the number of FDs for which we have space and if we are // requesting credentials. @@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msgFlags |= linux.MSG_CTRUNC } - if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { - if s.isPacket && n < int64(r.MsgSize) { + if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } @@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags total += n } - if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { + if err != nil || !waitAll || isPacket || n >= dst.NumBytes() { if total > 0 { err = nil } - if s.isPacket && n < int64(r.MsgSize) { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) @@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 { return s.ep.State() } +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + // Unix domain sockets always have a protocol of 0. + return linux.AF_UNIX, s.stype, 0 +} + // provider is a unix domain socket provider. type provider struct{} @@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs // Create the endpoint and socket. var ep transport.Endpoint - var isPacket bool switch stype { case linux.SOCK_DGRAM: - isPacket = true ep = transport.NewConnectionless() - case linux.SOCK_SEQPACKET: - isPacket = true - fallthrough - case linux.SOCK_STREAM: + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: ep = transport.NewConnectioned(stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } - return New(t, ep, isPacket), nil + return New(t, ep, stype), nil } // Pair creates a new pair of AF_UNIX connected sockets. @@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F return nil, nil, syserr.ErrProtocolNotSupported } - var isPacket bool switch stype { - case linux.SOCK_STREAM: - case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: - isPacket = true + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + // Ok default: return nil, nil, syserr.ErrInvalidArgument } // Create the endpoints and sockets. ep1, ep2 := transport.NewPair(stype, t.Kernel()) - s1 := New(t, ep1, isPacket) - s2 := New(t, ep2, isPacket) + s1 := New(t, ep1, stype) + s2 := New(t, ep2, stype) return s1, s2, nil } -- cgit v1.2.3