From 2d2831e3541c8ae3c84f17cfd1bf0a26f2027044 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 15:03:44 -0700 Subject: Track and export socket state. This is necessary for implementing network diagnostic interfaces like /proc/net/{tcp,udp,unix} and sock_diag(7). For pass-through endpoints such as hostinet, we obtain the socket state from the backend. For netstack, we add explicit tracking of TCP states. PiperOrigin-RevId: 251934850 --- pkg/sentry/socket/epsocket/epsocket.go | 44 ++++++++++++++++++++++++++++++++++ 1 file changed, 44 insertions(+) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index de4b963da..f91c5127a 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -52,6 +52,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" "gvisor.googlesource.com/gvisor/pkg/tcpip/stack" + "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp" "gvisor.googlesource.com/gvisor/pkg/waiter" ) @@ -2281,3 +2282,46 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { } return rv } + +// State implements socket.Socket.State. State translates the internal state +// returned by netstack to values defined by Linux. +func (s *SocketOperations) State() uint32 { + if s.family != linux.AF_INET && s.family != linux.AF_INET6 { + // States not implemented for this socket's family. + return 0 + } + + if !s.isPacketBased() { + // TCP socket. + switch tcp.EndpointState(s.Endpoint.State()) { + case tcp.StateEstablished: + return linux.TCP_ESTABLISHED + case tcp.StateSynSent: + return linux.TCP_SYN_SENT + case tcp.StateSynRecv: + return linux.TCP_SYN_RECV + case tcp.StateFinWait1: + return linux.TCP_FIN_WAIT1 + case tcp.StateFinWait2: + return linux.TCP_FIN_WAIT2 + case tcp.StateTimeWait: + return linux.TCP_TIME_WAIT + case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError: + return linux.TCP_CLOSE + case tcp.StateCloseWait: + return linux.TCP_CLOSE_WAIT + case tcp.StateLastAck: + return linux.TCP_LAST_ACK + case tcp.StateListen: + return linux.TCP_LISTEN + case tcp.StateClosing: + return linux.TCP_CLOSING + default: + // Internal or unknown state. + return 0 + } + } + + // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. + return 0 +} -- cgit v1.2.3 From 315cf9a523d409dc6ddd5ce25f8f0315068ccc67 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Thu, 6 Jun 2019 16:59:21 -0700 Subject: Use common definition of SockType. SockType isn't specific to unix domain sockets, and the current definition basically mirrors the linux ABI's definition. PiperOrigin-RevId: 251956740 --- pkg/abi/linux/socket.go | 18 +++++++++++------- pkg/sentry/fs/gofer/socket.go | 11 ++++++----- pkg/sentry/fs/host/socket.go | 11 ++++++----- pkg/sentry/socket/epsocket/BUILD | 1 - pkg/sentry/socket/epsocket/epsocket.go | 11 +++++------ pkg/sentry/socket/epsocket/provider.go | 7 +++---- pkg/sentry/socket/hostinet/BUILD | 1 - pkg/sentry/socket/hostinet/socket.go | 5 ++--- pkg/sentry/socket/netlink/provider.go | 7 +++---- pkg/sentry/socket/rpcinet/BUILD | 1 - pkg/sentry/socket/rpcinet/socket.go | 9 ++++----- pkg/sentry/socket/socket.go | 8 ++++---- pkg/sentry/socket/unix/transport/connectioned.go | 20 ++++++++++---------- pkg/sentry/socket/unix/transport/connectionless.go | 4 ++-- pkg/sentry/socket/unix/transport/unix.go | 22 ++++------------------ pkg/sentry/socket/unix/unix.go | 4 ++-- pkg/sentry/strace/socket.go | 14 +++++++------- pkg/sentry/syscalls/linux/sys_socket.go | 4 ++-- 18 files changed, 71 insertions(+), 87 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/abi/linux/socket.go b/pkg/abi/linux/socket.go index 44bd69df6..a714ac86d 100644 --- a/pkg/abi/linux/socket.go +++ b/pkg/abi/linux/socket.go @@ -102,15 +102,19 @@ const ( SOL_NETLINK = 270 ) +// A SockType is a type (as opposed to family) of sockets. These are enumerated +// below as SOCK_* constants. +type SockType int + // Socket types, from linux/net.h. const ( - SOCK_STREAM = 1 - SOCK_DGRAM = 2 - SOCK_RAW = 3 - SOCK_RDM = 4 - SOCK_SEQPACKET = 5 - SOCK_DCCP = 6 - SOCK_PACKET = 10 + SOCK_STREAM SockType = 1 + SOCK_DGRAM = 2 + SOCK_RAW = 3 + SOCK_RDM = 4 + SOCK_SEQPACKET = 5 + SOCK_DCCP = 6 + SOCK_PACKET = 10 ) // SOCK_TYPE_MASK covers all of the above socket types. The remaining bits are diff --git a/pkg/sentry/fs/gofer/socket.go b/pkg/sentry/fs/gofer/socket.go index 7376fd76f..7ac0a421f 100644 --- a/pkg/sentry/fs/gofer/socket.go +++ b/pkg/sentry/fs/gofer/socket.go @@ -15,6 +15,7 @@ package gofer import ( + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/log" "gvisor.googlesource.com/gvisor/pkg/p9" "gvisor.googlesource.com/gvisor/pkg/sentry/fs" @@ -61,13 +62,13 @@ type endpoint struct { path string } -func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { +func sockTypeToP9(t linux.SockType) (p9.ConnectFlags, bool) { switch t { - case transport.SockStream: + case linux.SOCK_STREAM: return p9.StreamSocket, true - case transport.SockSeqpacket: + case linux.SOCK_SEQPACKET: return p9.SeqpacketSocket, true - case transport.SockDgram: + case linux.SOCK_DGRAM: return p9.DgramSocket, true } return 0, false @@ -75,7 +76,7 @@ func unixSockToP9(t transport.SockType) (p9.ConnectFlags, bool) { // BidirectionalConnect implements ConnectableEndpoint.BidirectionalConnect. func (e *endpoint) BidirectionalConnect(ce transport.ConnectingEndpoint, returnConnect func(transport.Receiver, transport.ConnectedEndpoint)) *syserr.Error { - cf, ok := unixSockToP9(ce.Type()) + cf, ok := sockTypeToP9(ce.Type()) if !ok { return syserr.ErrConnectionRefused } diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index e4ec0f62c..6423ad938 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -19,6 +19,7 @@ import ( "sync" "syscall" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/fd" "gvisor.googlesource.com/gvisor/pkg/fdnotifier" "gvisor.googlesource.com/gvisor/pkg/log" @@ -56,7 +57,7 @@ type ConnectedEndpoint struct { srfd int `state:"wait"` // stype is the type of Unix socket. - stype transport.SockType + stype linux.SockType // sndbuf is the size of the send buffer. // @@ -105,7 +106,7 @@ func (c *ConnectedEndpoint) init() *syserr.Error { return syserr.ErrInvalidEndpointState } - c.stype = transport.SockType(stype) + c.stype = linux.SockType(stype) c.sndbuf = sndbuf return nil @@ -163,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != transport.SockStream, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -195,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != transport.SockStream), nil + return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil } // Send implements transport.ConnectedEndpoint.Send. @@ -209,7 +210,7 @@ func (c *ConnectedEndpoint) Send(data [][]byte, controlMessages transport.Contro // Since stream sockets don't preserve message boundaries, we can write // only as much of the message as fits in the send buffer. - truncate := c.stype == transport.SockStream + truncate := c.stype == linux.SOCK_STREAM n, totalLen, err := fdWriteVec(c.file.FD(), data, c.sndbuf, truncate) if n < totalLen && err == nil { diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD index 44bb97b5b..7e2679ea0 100644 --- a/pkg/sentry/socket/epsocket/BUILD +++ b/pkg/sentry/socket/epsocket/BUILD @@ -32,7 +32,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f91c5127a..e1e29de35 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -44,7 +44,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -228,7 +227,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint - skType transport.SockType + skType linux.SockType // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -253,8 +252,8 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { - if skType == transport.SockStream { +func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { + if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) } @@ -638,7 +637,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) ( // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -664,7 +663,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, } // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_TYPE: diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index ec930d8d5..e48a106ea 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -23,7 +23,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/header" @@ -42,7 +41,7 @@ type provider struct { // getTransportProtocol figures out transport protocol. Currently only TCP, // UDP, and ICMP are supported. -func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { +func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) { switch stype { case linux.SOCK_STREAM: if protocol != 0 && protocol != syscall.IPPROTO_TCP { @@ -80,7 +79,7 @@ func getTransportProtocol(ctx context.Context, stype transport.SockType, protoco } // Socket creates a new socket object for the AF_INET or AF_INET6 family. -func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Fail right away if we don't have a stack. stack := t.NetworkContext() if stack == nil { @@ -116,7 +115,7 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int } // Pair just returns nil sockets (not supported). -func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil } diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index a469af7ac..975f47bc3 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -30,7 +30,6 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/syserror", diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 0d75580a3..4517951a0 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -30,7 +30,6 @@ import ( ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time" "gvisor.googlesource.com/gvisor/pkg/sentry/safemem" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/syserror" @@ -548,7 +547,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the host network stack. stack := t.NetworkContext() if stack == nil { @@ -590,7 +589,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 76cf12fd4..863edc241 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -22,7 +22,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" "gvisor.googlesource.com/gvisor/pkg/sentry/socket" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/syserr" ) @@ -66,10 +65,10 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Netlink sockets must be specified as datagram or raw, but they // behave the same regardless of type. - if stype != transport.SockDgram && stype != transport.SockRaw { + if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW { return nil, syserr.ErrSocketNotSupported } @@ -94,7 +93,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol } // Pair implements socket.Provider.Pair by returning an error. -func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) { +func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { // Netlink sockets never supports creating socket pairs. return nil, nil, syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD index 4da14a1e0..33ba20de7 100644 --- a/pkg/sentry/socket/rpcinet/BUILD +++ b/pkg/sentry/socket/rpcinet/BUILD @@ -31,7 +31,6 @@ go_library( "//pkg/sentry/socket/hostinet", "//pkg/sentry/socket/rpcinet/conn", "//pkg/sentry/socket/rpcinet/notifier", - "//pkg/sentry/socket/unix/transport", "//pkg/sentry/unimpl", "//pkg/sentry/usermem", "//pkg/syserr", diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index bf42bdf69..2d5b5b58f 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -32,7 +32,6 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier" pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" - "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" "gvisor.googlesource.com/gvisor/pkg/sentry/unimpl" "gvisor.googlesource.com/gvisor/pkg/sentry/usermem" "gvisor.googlesource.com/gvisor/pkg/syserr" @@ -70,7 +69,7 @@ type socketOperations struct { var _ = socket.Socket(&socketOperations{}) // New creates a new RPC socket. -func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) { +func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) { id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */) <-c @@ -841,7 +840,7 @@ type socketProvider struct { } // Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check that we are using the RPC network stack. stack := t.NetworkContext() if stack == nil { @@ -857,7 +856,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p // // Try to restrict the flags we will accept to minimize backwards // incompatibility with netstack. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -881,7 +880,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p } // Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Not supported by AF_INET/AF_INET6. return nil, nil, nil } diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index a99423365..f1021ec67 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -130,12 +130,12 @@ type Provider interface { // If a nil Socket _and_ a nil error is returned, it means that the // protocol is not supported. A non-nil error should only be returned // if the protocol is supported, but an error occurs during creation. - Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) + Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) // Pair creates a pair of connected sockets. // // See Socket for error information. - Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) + Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) } // families holds a map of all known address families and their providers. @@ -149,7 +149,7 @@ func RegisterProvider(family int, provider Provider) { } // New creates a new socket with the given family, type and protocol. -func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { for _, p := range families[family] { s, err := p.Socket(t, stype, protocol) if err != nil { @@ -166,7 +166,7 @@ func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*f // Pair creates a new connected socket pair with the given family, type and // protocol. -func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { providers, ok := families[family] if !ok { return nil, nil, syserr.ErrAddressFamilyNotSupported diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index 9c8ec0365..db79ac904 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -45,7 +45,7 @@ type ConnectingEndpoint interface { // Type returns the socket type, typically either SockStream or // SockSeqpacket. The connection attempt must be aborted if this // value doesn't match the ConnectableEndpoint's type. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the bound path. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -101,7 +101,7 @@ type connectionedEndpoint struct { // stype is used by connecting sockets to ensure that they are the // same type. The value is typically either tcpip.SockSeqpacket or // tcpip.SockStream. - stype SockType + stype linux.SockType // acceptedChan is per the TCP endpoint implementation. Note that the // sockets in this channel are _already in the connected state_, and @@ -112,7 +112,7 @@ type connectionedEndpoint struct { } // NewConnectioned creates a new unbound connectionedEndpoint. -func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { +func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -122,7 +122,7 @@ func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint { } // NewPair allocates a new pair of connected unix-domain connectionedEndpoints. -func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { +func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { a := &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}}, id: uid.UniqueID(), @@ -139,7 +139,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit} q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit} - if stype == SockStream { + if stype == linux.SOCK_STREAM { a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}} b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}} } else { @@ -163,7 +163,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) { // NewExternal creates a new externally backed Endpoint. It behaves like a // socketpair. -func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { +func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint { return &connectionedEndpoint{ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected}, id: uid.UniqueID(), @@ -178,7 +178,7 @@ func (e *connectionedEndpoint) ID() uint64 { } // Type implements ConnectingEndpoint.Type and Endpoint.Type. -func (e *connectionedEndpoint) Type() SockType { +func (e *connectionedEndpoint) Type() linux.SockType { return e.stype } @@ -294,7 +294,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur } writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit} - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}} } else { ne.receiver = &queueReceiver{readQueue: writeQueue} @@ -309,7 +309,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur writeQueue: writeQueue, } readQueue.IncRef() - if e.stype == SockStream { + if e.stype == linux.SOCK_STREAM { returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected) } else { returnConnect(&queueReceiver{readQueue: readQueue}, connected) @@ -429,7 +429,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) { // Stream sockets do not support specifying the endpoint. Seqpacket // sockets ignore the passed endpoint. - if e.stype == SockStream && to != nil { + if e.stype == linux.SOCK_STREAM && to != nil { return 0, syserr.ErrNotSupported } return e.baseEndpoint.SendMsg(data, c, to) diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index c034cf984..81ebfba10 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -119,8 +119,8 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo } // Type implements Endpoint.Type. -func (e *connectionlessEndpoint) Type() SockType { - return SockDgram +func (e *connectionlessEndpoint) Type() linux.SockType { + return linux.SOCK_DGRAM } // Connect attempts to connect directly to server. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 5fc09af55..5c55c529e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -19,6 +19,7 @@ import ( "sync" "sync/atomic" + "gvisor.googlesource.com/gvisor/pkg/abi/linux" "gvisor.googlesource.com/gvisor/pkg/syserr" "gvisor.googlesource.com/gvisor/pkg/tcpip" "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer" @@ -28,21 +29,6 @@ import ( // initialLimit is the starting limit for the socket buffers. const initialLimit = 16 * 1024 -// A SockType is a type (as opposed to family) of sockets. These are enumerated -// in the syscall package as syscall.SOCK_* constants. -type SockType int - -const ( - // SockStream corresponds to syscall.SOCK_STREAM. - SockStream SockType = 1 - // SockDgram corresponds to syscall.SOCK_DGRAM. - SockDgram SockType = 2 - // SockRaw corresponds to syscall.SOCK_RAW. - SockRaw SockType = 3 - // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET. - SockSeqpacket SockType = 5 -) - // A RightsControlMessage is a control message containing FDs. type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. @@ -175,7 +161,7 @@ type Endpoint interface { // Type return the socket type, typically either SockStream, SockDgram // or SockSeqpacket. - Type() SockType + Type() linux.SockType // GetLocalAddress returns the address to which the endpoint is bound. GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) @@ -629,7 +615,7 @@ type connectedEndpoint struct { GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) // Type implements Endpoint.Type. - Type() SockType + Type() linux.SockType } writeQueue *queue @@ -653,7 +639,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, } truncate := false - if e.endpoint.Type() == SockStream { + if e.endpoint.Type() == linux.SOCK_STREAM { // Since stream sockets don't preserve message boundaries, we // can write only as much of the message as fits in the queue. truncate = true diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 375542350..56ed63e21 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -605,7 +605,7 @@ func (s *SocketOperations) State() uint32 { type provider struct{} // Socket returns a new unix domain socket. -func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) { +func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, syserr.ErrProtocolNotSupported @@ -631,7 +631,7 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) } // Pair creates a new pair of AF_UNIX connected sockets. -func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { +func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { // Check arguments. if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { return nil, nil, syserr.ErrProtocolNotSupported diff --git a/pkg/sentry/strace/socket.go b/pkg/sentry/strace/socket.go index dbe53b9a2..0b5ef84c4 100644 --- a/pkg/sentry/strace/socket.go +++ b/pkg/sentry/strace/socket.go @@ -76,13 +76,13 @@ var SocketFamily = abi.ValueSet{ // SocketType are the possible socket(2) types. var SocketType = abi.ValueSet{ - linux.SOCK_STREAM: "SOCK_STREAM", - linux.SOCK_DGRAM: "SOCK_DGRAM", - linux.SOCK_RAW: "SOCK_RAW", - linux.SOCK_RDM: "SOCK_RDM", - linux.SOCK_SEQPACKET: "SOCK_SEQPACKET", - linux.SOCK_DCCP: "SOCK_DCCP", - linux.SOCK_PACKET: "SOCK_PACKET", + uint64(linux.SOCK_STREAM): "SOCK_STREAM", + uint64(linux.SOCK_DGRAM): "SOCK_DGRAM", + uint64(linux.SOCK_RAW): "SOCK_RAW", + uint64(linux.SOCK_RDM): "SOCK_RDM", + uint64(linux.SOCK_SEQPACKET): "SOCK_SEQPACKET", + uint64(linux.SOCK_DCCP): "SOCK_DCCP", + uint64(linux.SOCK_PACKET): "SOCK_PACKET", } // SocketFlagSet are the possible socket(2) flags. diff --git a/pkg/sentry/syscalls/linux/sys_socket.go b/pkg/sentry/syscalls/linux/sys_socket.go index 8f4dbf3bc..31295a6a9 100644 --- a/pkg/sentry/syscalls/linux/sys_socket.go +++ b/pkg/sentry/syscalls/linux/sys_socket.go @@ -188,7 +188,7 @@ func Socket(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal } // Create the new socket. - s, e := socket.New(t, domain, transport.SockType(stype&0xf), protocol) + s, e := socket.New(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } @@ -227,7 +227,7 @@ func SocketPair(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Sy } // Create the socket pair. - s1, s2, e := socket.Pair(t, domain, transport.SockType(stype&0xf), protocol) + s1, s2, e := socket.Pair(t, domain, linux.SockType(stype&0xf), protocol) if e != nil { return 0, nil, e.ToError() } -- cgit v1.2.3 From a00157cc0e216a9829f2659ce35c856a22aa5ba2 Mon Sep 17 00:00:00 2001 From: Rahat Mahmood Date: Mon, 10 Jun 2019 15:16:42 -0700 Subject: Store more information in the kernel socket table. Store enough information in the kernel socket table to distinguish between different types of sockets. Previously we were only storing the socket family, but this isn't enough to classify sockets. For example, TCPv4 and UDPv4 sockets are both AF_INET, and ICMP sockets are SOCK_DGRAM sockets with a particular protocol. Instead of creating more sub-tables, flatten the socket table and provide a filtering mechanism based on the socket entry. Also generate and store a socket entry index ("sl" in linux) which allows us to output entries in a stable order from procfs. PiperOrigin-RevId: 252495895 --- pkg/sentry/fs/host/socket.go | 4 +-- pkg/sentry/fs/proc/BUILD | 1 + pkg/sentry/fs/proc/net.go | 14 ++++---- pkg/sentry/kernel/BUILD | 13 +++++++ pkg/sentry/kernel/kernel.go | 55 +++++++++++++--------------- pkg/sentry/socket/epsocket/epsocket.go | 13 +++++-- pkg/sentry/socket/epsocket/provider.go | 2 +- pkg/sentry/socket/hostinet/socket.go | 32 +++++++++++------ pkg/sentry/socket/netlink/provider.go | 2 +- pkg/sentry/socket/netlink/socket.go | 12 ++++++- pkg/sentry/socket/rpcinet/socket.go | 16 +++++++-- pkg/sentry/socket/socket.go | 9 +++-- pkg/sentry/socket/unix/unix.go | 65 ++++++++++++++++++++-------------- 13 files changed, 152 insertions(+), 86 deletions(-) (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/fs/host/socket.go b/pkg/sentry/fs/host/socket.go index 6423ad938..305eea718 100644 --- a/pkg/sentry/fs/host/socket.go +++ b/pkg/sentry/fs/host/socket.go @@ -164,7 +164,7 @@ func NewSocketWithDirent(ctx context.Context, d *fs.Dirent, f *fd.FD, flags fs.F ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.NewWithDirent(ctx, d, ep, e.stype != linux.SOCK_STREAM, flags), nil + return unixsocket.NewWithDirent(ctx, d, ep, e.stype, flags), nil } // newSocket allocates a new unix socket with host endpoint. @@ -196,7 +196,7 @@ func newSocket(ctx context.Context, orgfd int, saveable bool) (*fs.File, error) ep := transport.NewExternal(e.stype, uniqueid.GlobalProviderFromContext(ctx), &q, e, e) - return unixsocket.New(ctx, ep, e.stype != linux.SOCK_STREAM), nil + return unixsocket.New(ctx, ep, e.stype), nil } // Send implements transport.ConnectedEndpoint.Send. diff --git a/pkg/sentry/fs/proc/BUILD b/pkg/sentry/fs/proc/BUILD index d19c360e0..1728fe0b5 100644 --- a/pkg/sentry/fs/proc/BUILD +++ b/pkg/sentry/fs/proc/BUILD @@ -45,6 +45,7 @@ go_library( "//pkg/sentry/kernel/time", "//pkg/sentry/limits", "//pkg/sentry/mm", + "//pkg/sentry/socket", "//pkg/sentry/socket/rpcinet", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", diff --git a/pkg/sentry/fs/proc/net.go b/pkg/sentry/fs/proc/net.go index 3daaa962c..034950158 100644 --- a/pkg/sentry/fs/proc/net.go +++ b/pkg/sentry/fs/proc/net.go @@ -27,6 +27,7 @@ import ( "gvisor.googlesource.com/gvisor/pkg/sentry/fs/ramfs" "gvisor.googlesource.com/gvisor/pkg/sentry/inet" "gvisor.googlesource.com/gvisor/pkg/sentry/kernel" + "gvisor.googlesource.com/gvisor/pkg/sentry/socket" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix" "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport" ) @@ -213,17 +214,18 @@ func (n *netUnix) ReadSeqFileData(ctx context.Context, h seqfile.SeqHandle) ([]s fmt.Fprintf(&buf, "Num RefCount Protocol Flags Type St Inode Path\n") // Entries - for _, sref := range n.k.ListSockets(linux.AF_UNIX) { - s := sref.Get() + for _, se := range n.k.ListSockets() { + s := se.Sock.Get() if s == nil { - log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", sref) + log.Debugf("Couldn't resolve weakref %v in socket table, racing with destruction?", se.Sock) continue } sfile := s.(*fs.File) - sops, ok := sfile.FileOperations.(*unix.SocketOperations) - if !ok { - panic(fmt.Sprintf("Found non-unix socket file in unix socket table: %+v", sfile)) + if family, _, _ := sfile.FileOperations.(socket.Socket).Type(); family != linux.AF_UNIX { + // Not a unix socket. + continue } + sops := sfile.FileOperations.(*unix.SocketOperations) addr, err := sops.Endpoint().GetLocalAddress() if err != nil { diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index 99a2fd964..04e375910 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -64,6 +64,18 @@ go_template_instance( }, ) +go_template_instance( + name = "socket_list", + out = "socket_list.go", + package = "kernel", + prefix = "socket", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*SocketEntry", + "Linker": "*SocketEntry", + }, +) + proto_library( name = "uncaught_signal_proto", srcs = ["uncaught_signal.proto"], @@ -104,6 +116,7 @@ go_library( "sessions.go", "signal.go", "signal_handlers.go", + "socket_list.go", "syscalls.go", "syscalls_state.go", "syslog.go", diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 85d73ace2..f253a81d9 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -182,9 +182,13 @@ type Kernel struct { // danglingEndpoints is used to save / restore tcpip.DanglingEndpoints. danglingEndpoints struct{} `state:".([]tcpip.Endpoint)"` - // socketTable is used to track all sockets on the system. Protected by + // sockets is the list of all network sockets the system. Protected by // extMu. - socketTable map[int]map[*refs.WeakRef]struct{} + sockets socketList + + // nextSocketEntry is the next entry number to use in sockets. Protected + // by extMu. + nextSocketEntry uint64 // deviceRegistry is used to save/restore device.SimpleDevices. deviceRegistry struct{} `state:".(*device.Registry)"` @@ -283,7 +287,6 @@ func (k *Kernel) Init(args InitKernelArgs) error { k.monotonicClock = &timekeeperClock{tk: args.Timekeeper, c: sentrytime.Monotonic} k.futexes = futex.NewManager() k.netlinkPorts = port.New() - k.socketTable = make(map[int]map[*refs.WeakRef]struct{}) return nil } @@ -1137,51 +1140,43 @@ func (k *Kernel) EmitUnimplementedEvent(ctx context.Context) { }) } -// socketEntry represents a socket recorded in Kernel.socketTable. It implements +// SocketEntry represents a socket recorded in Kernel.sockets. It implements // refs.WeakRefUser for sockets stored in the socket table. // // +stateify savable -type socketEntry struct { - k *Kernel - sock *refs.WeakRef - family int +type SocketEntry struct { + socketEntry + k *Kernel + Sock *refs.WeakRef + ID uint64 // Socket table entry number. } // WeakRefGone implements refs.WeakRefUser.WeakRefGone. -func (s *socketEntry) WeakRefGone() { +func (s *SocketEntry) WeakRefGone() { s.k.extMu.Lock() - // k.socketTable is guaranteed to point to a valid socket table for s.family - // at this point, since we made sure of the fact when we created this - // socketEntry, and we never delete socket tables. - delete(s.k.socketTable[s.family], s.sock) + s.k.sockets.Remove(s) s.k.extMu.Unlock() } // RecordSocket adds a socket to the system-wide socket table for tracking. // // Precondition: Caller must hold a reference to sock. -func (k *Kernel) RecordSocket(sock *fs.File, family int) { +func (k *Kernel) RecordSocket(sock *fs.File) { k.extMu.Lock() - table, ok := k.socketTable[family] - if !ok { - table = make(map[*refs.WeakRef]struct{}) - k.socketTable[family] = table - } - se := socketEntry{k: k, family: family} - se.sock = refs.NewWeakRef(sock, &se) - table[se.sock] = struct{}{} + id := k.nextSocketEntry + k.nextSocketEntry++ + s := &SocketEntry{k: k, ID: id} + s.Sock = refs.NewWeakRef(sock, s) + k.sockets.PushBack(s) k.extMu.Unlock() } -// ListSockets returns a snapshot of all sockets of a given family. -func (k *Kernel) ListSockets(family int) []*refs.WeakRef { +// ListSockets returns a snapshot of all sockets. +func (k *Kernel) ListSockets() []*SocketEntry { k.extMu.Lock() - socks := []*refs.WeakRef{} - if table, ok := k.socketTable[family]; ok { - socks = make([]*refs.WeakRef, 0, len(table)) - for s := range table { - socks = append(socks, s) - } + var socks []*SocketEntry + for s := k.sockets.Front(); s != nil; s = s.Next() { + socks = append(socks, s) } k.extMu.Unlock() return socks diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index e1e29de35..f67451179 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -228,6 +228,7 @@ type SocketOperations struct { family int Endpoint tcpip.Endpoint skType linux.SockType + protocol int // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` @@ -252,7 +253,7 @@ type SocketOperations struct { } // New creates a new endpoint socket. -func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { +func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil { return nil, syserr.TranslateNetstackError(err) @@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, queue *waiter.Queue, family: family, Endpoint: endpoint, skType: skType, + protocol: protocol, }), nil } @@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns, err := New(t, s.family, s.skType, wq, ep) + ns, err := New(t, s.family, s.skType, s.protocol, wq, ep) if err != nil { return 0, nil, 0, err } @@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(ns, s.family) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, syserr.FromError(e) } @@ -2324,3 +2326,8 @@ func (s *SocketOperations) State() uint32 { // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets. return 0 } + +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.skType, s.protocol +} diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go index e48a106ea..516582828 100644 --- a/pkg/sentry/socket/epsocket/provider.go +++ b/pkg/sentry/socket/epsocket/provider.go @@ -111,7 +111,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* return nil, syserr.TranslateNetstackError(e) } - return New(t, p.family, stype, wq, ep) + return New(t, p.family, stype, protocol, wq, ep) } // Pair just returns nil sockets (not supported). diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 4517951a0..c62c8d8f1 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -56,15 +56,22 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. - fd int // must be O_NONBLOCK - queue waiter.Queue + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd int // must be O_NONBLOCK + queue waiter.Queue } var _ = socket.Socket(&socketOperations{}) -func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) { - s := &socketOperations{family: family, fd: fd} +func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { + s := &socketOperations{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } @@ -222,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0) + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) if err != nil { syscall.Close(fd) return 0, nil, 0, err @@ -233,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, } kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits()) - t.Kernel().RecordSocket(f, s.family) + t.Kernel().RecordSocket(f) return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } @@ -542,6 +549,11 @@ func (s *socketOperations) State() uint32 { return uint32(info.State) } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -558,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto } // Only accept TCP and UDP. - stype := int(stypeflags) & linux.SOCK_TYPE_MASK + stype := stypeflags & linux.SOCK_TYPE_MASK switch stype { case syscall.SOCK_STREAM: switch protocol { @@ -581,11 +593,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto // Conservatively ignore all flags specified by the application and add // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. - fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) if err != nil { return nil, syserr.FromError(err) } - return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) + return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 863edc241..5dc103877 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -82,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int return nil, err } - s, err := NewSocket(t, p) + s, err := NewSocket(t, stype, p) if err != nil { return nil, err } diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index 16c79aa33..62659784a 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -80,6 +80,10 @@ type Socket struct { // protocol is the netlink protocol implementation. protocol Protocol + // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for + // netlink sockets. + skType linux.SockType + // ep is a datagram unix endpoint used to buffer messages sent from the // kernel to userspace. RecvMsg reads messages from this endpoint. ep transport.Endpoint @@ -105,7 +109,7 @@ type Socket struct { var _ socket.Socket = (*Socket)(nil) // NewSocket creates a new Socket. -func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { +func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { // Datagram endpoint used to buffer kernel -> user messages. ep := transport.NewConnectionless() @@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) { return &Socket{ ports: t.Kernel().NetlinkPorts(), protocol: protocol, + skType: skType, ep: ep, connection: connection, sendBufferSize: defaultSendBufferSize, @@ -621,3 +626,8 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, func (s *Socket) State() uint32 { return s.ep.State() } + +// Type implements socket.Socket.Type. +func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { + return linux.AF_NETLINK, s.skType, s.protocol.Protocol() +} diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go index 2d5b5b58f..c22ff1ff0 100644 --- a/pkg/sentry/socket/rpcinet/socket.go +++ b/pkg/sentry/socket/rpcinet/socket.go @@ -53,7 +53,10 @@ type socketOperations struct { fsutil.FileUseInodeUnstableAttr `state:"nosave"` socket.SendReceiveTimeout - family int // Read-only. + family int // Read-only. + stype linux.SockType // Read-only. + protocol int // Read-only. + fd uint32 // must be O_NONBLOCK wq *waiter.Queue rpcConn *conn.RPCConnection @@ -86,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.S defer dirent.DecRef() return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{ family: family, + stype: skType, + protocol: protocol, wq: &wq, fd: fd, rpcConn: stack.rpcConn, @@ -332,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, syserr.FromError(err) } - t.Kernel().RecordSocket(file, s.family) + t.Kernel().RecordSocket(file) if peerRequested { return fd, payload.Address.Address, payload.Address.Length, nil @@ -835,6 +840,11 @@ func (s *socketOperations) State() uint32 { return 0 } +// Type implements socket.Socket.Type. +func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { + return s.family, s.stype, s.protocol +} + type socketProvider struct { family int } @@ -876,7 +886,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, proto return nil, nil } - return newSocketFile(t, s, p.family, stype, 0) + return newSocketFile(t, s, p.family, stype, protocol) } // Pair implements socket.Provider.Pair. diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index f1021ec67..d60944b6b 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -120,6 +120,9 @@ type Socket interface { // State returns the current state of the socket, as represented by Linux in // procfs. The returned state value is protocol-specific. State() uint32 + + // Type returns the family, socket type and protocol of the socket. + Type() (family int, skType linux.SockType, protocol int) } // Provider is the interface implemented by providers of sockets for specific @@ -156,7 +159,7 @@ func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.Fi return nil, err } if s != nil { - t.Kernel().RecordSocket(s, family) + t.Kernel().RecordSocket(s) return s, nil } } @@ -179,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.F } if s1 != nil && s2 != nil { k := t.Kernel() - k.RecordSocket(s1, family) - k.RecordSocket(s2, family) + k.RecordSocket(s1) + k.RecordSocket(s2) return s1, s2, nil } } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 56ed63e21..b07e8d67b 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -17,6 +17,7 @@ package unix import ( + "fmt" "strings" "syscall" @@ -55,22 +56,22 @@ type SocketOperations struct { refs.AtomicRefCount socket.SendReceiveTimeout - ep transport.Endpoint - isPacket bool + ep transport.Endpoint + stype linux.SockType } // New creates a new unix socket. -func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File { +func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) defer dirent.DecRef() - return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true}) + return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true}) } // NewWithDirent creates a new unix socket using an existing dirent. -func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File { +func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File { return fs.NewFile(ctx, d, flags, &SocketOperations{ - ep: ep, - isPacket: isPacket, + ep: ep, + stype: stype, }) } @@ -88,6 +89,18 @@ func (s *SocketOperations) Release() { s.DecRef() } +func (s *SocketOperations) isPacket() bool { + switch s.stype { + case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + return true + case linux.SOCK_STREAM: + return false + default: + // We shouldn't have allowed any other socket types during creation. + panic(fmt.Sprintf("Invalid socket type %d", s.stype)) + } +} + // Endpoint extracts the transport.Endpoint. func (s *SocketOperations) Endpoint() transport.Endpoint { return s.ep @@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } } - ns := New(t, ep, s.isPacket) + ns := New(t, ep, s.stype) defer ns.DecRef() if flags&linux.SOCK_NONBLOCK != 0 { @@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, nil, 0, syserr.FromError(e) } - t.Kernel().RecordSocket(ns, linux.AF_UNIX) + t.Kernel().RecordSocket(ns) return fd, addr, addrLen, nil } @@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 waitAll := flags&linux.MSG_WAITALL != 0 + isPacket := s.isPacket() // Calculate the number of FDs for which we have space and if we are // requesting credentials. @@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msgFlags |= linux.MSG_CTRUNC } - if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() { - if s.isPacket && n < int64(r.MsgSize) { + if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } @@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags total += n } - if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() { + if err != nil || !waitAll || isPacket || n >= dst.NumBytes() { if total > 0 { err = nil } - if s.isPacket && n < int64(r.MsgSize) { + if isPacket && n < int64(r.MsgSize) { msgFlags |= linux.MSG_TRUNC } return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err) @@ -601,6 +615,12 @@ func (s *SocketOperations) State() uint32 { return s.ep.State() } +// Type implements socket.Socket.Type. +func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { + // Unix domain sockets always have a protocol of 0. + return linux.AF_UNIX, s.stype, 0 +} + // provider is a unix domain socket provider. type provider struct{} @@ -613,21 +633,16 @@ func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs // Create the endpoint and socket. var ep transport.Endpoint - var isPacket bool switch stype { case linux.SOCK_DGRAM: - isPacket = true ep = transport.NewConnectionless() - case linux.SOCK_SEQPACKET: - isPacket = true - fallthrough - case linux.SOCK_STREAM: + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: ep = transport.NewConnectioned(stype, t.Kernel()) default: return nil, syserr.ErrInvalidArgument } - return New(t, ep, isPacket), nil + return New(t, ep, stype), nil } // Pair creates a new pair of AF_UNIX connected sockets. @@ -637,19 +652,17 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F return nil, nil, syserr.ErrProtocolNotSupported } - var isPacket bool switch stype { - case linux.SOCK_STREAM: - case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: - isPacket = true + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: + // Ok default: return nil, nil, syserr.ErrInvalidArgument } // Create the endpoints and sockets. ep1, ep2 := transport.NewPair(stype, t.Kernel()) - s1 := New(t, ep1, isPacket) - s2 := New(t, ep2, isPacket) + s1 := New(t, ep1, stype) + s2 := New(t, ep2, stype) return s1, s2, nil } -- cgit v1.2.3 From 70578806e8d3e01fae2249b3e602cd5b05d378a0 Mon Sep 17 00:00:00 2001 From: Bhasker Hariharan Date: Wed, 12 Jun 2019 13:34:47 -0700 Subject: Add support for TCP_CONGESTION socket option. This CL also cleans up the error returned for setting congestion control which was incorrectly returning EINVAL instead of ENOENT. PiperOrigin-RevId: 252889093 --- pkg/sentry/socket/epsocket/epsocket.go | 30 +++++ pkg/tcpip/tcpip.go | 8 ++ pkg/tcpip/transport/tcp/BUILD | 1 + pkg/tcpip/transport/tcp/cubic.go | 3 +- pkg/tcpip/transport/tcp/cubic_state.go | 29 +++++ pkg/tcpip/transport/tcp/endpoint.go | 45 +++++++- pkg/tcpip/transport/tcp/protocol.go | 22 ++-- pkg/tcpip/transport/tcp/snd.go | 10 +- pkg/tcpip/transport/tcp/tcp_test.go | 112 +++++++++++++++--- pkg/tcpip/transport/tcp/testing/context/context.go | 52 +++++---- test/syscalls/linux/socket_ip_tcp_generic.cc | 104 +++++++++++++++++ test/syscalls/linux/tcp_socket.cc | 127 +++++++++++++++++++++ 12 files changed, 481 insertions(+), 62 deletions(-) create mode 100644 pkg/tcpip/transport/tcp/cubic_state.go (limited to 'pkg/sentry/socket/epsocket') diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go index f67451179..a50798cb3 100644 --- a/pkg/sentry/socket/epsocket/epsocket.go +++ b/pkg/sentry/socket/epsocket/epsocket.go @@ -920,6 +920,30 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa t.Kernel().EmitUnimplementedEvent(t) + case linux.TCP_CONGESTION: + if outLen <= 0 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.CongestionControlOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + // We match linux behaviour here where it returns the lower of + // TCP_CA_NAME_MAX bytes or the value of the option length. + // + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const tcpCANameMax = 16 + + toCopy := tcpCANameMax + if outLen < tcpCANameMax { + toCopy = outLen + } + b := make([]byte, toCopy) + copy(b, v) + return b, nil + default: emitUnimplementedEventTCP(t, name) } @@ -1222,6 +1246,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_CONGESTION: + v := tcpip.CongestionControlOption(optVal) + if err := ep.SetSockOpt(v); err != nil { + return syserr.TranslateNetstackError(err) + } + return nil case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 85ef014d0..04c776205 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -472,6 +472,14 @@ type KeepaliveIntervalOption time.Duration // closed. type KeepaliveCountOption int +// CongestionControlOption is used by SetSockOpt/GetSockOpt to set/get +// the current congestion control algorithm. +type CongestionControlOption string + +// AvailableCongestionControlOption is used to query the supported congestion +// control algorithms. +type AvailableCongestionControlOption string + // MulticastTTLOption is used by SetSockOpt/GetSockOpt to control the default // TTL value for multicast messages. The default is 1. type MulticastTTLOption uint8 diff --git a/pkg/tcpip/transport/tcp/BUILD b/pkg/tcpip/transport/tcp/BUILD index 9db38196b..a9dbfb930 100644 --- a/pkg/tcpip/transport/tcp/BUILD +++ b/pkg/tcpip/transport/tcp/BUILD @@ -21,6 +21,7 @@ go_library( "accept.go", "connect.go", "cubic.go", + "cubic_state.go", "endpoint.go", "endpoint_state.go", "forwarder.go", diff --git a/pkg/tcpip/transport/tcp/cubic.go b/pkg/tcpip/transport/tcp/cubic.go index e618cd2b9..7b1f5e763 100644 --- a/pkg/tcpip/transport/tcp/cubic.go +++ b/pkg/tcpip/transport/tcp/cubic.go @@ -23,6 +23,7 @@ import ( // control algorithm state. // // See: https://tools.ietf.org/html/rfc8312. +// +stateify savable type cubicState struct { // wLastMax is the previous wMax value. wLastMax float64 @@ -33,7 +34,7 @@ type cubicState struct { // t denotes the time when the current congestion avoidance // was entered. - t time.Time + t time.Time `state:".(unixTime)"` // numCongestionEvents tracks the number of congestion events since last // RTO. diff --git a/pkg/tcpip/transport/tcp/cubic_state.go b/pkg/tcpip/transport/tcp/cubic_state.go new file mode 100644 index 000000000..d0f58cfaf --- /dev/null +++ b/pkg/tcpip/transport/tcp/cubic_state.go @@ -0,0 +1,29 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tcp + +import ( + "time" +) + +// saveT is invoked by stateify. +func (c *cubicState) saveT() unixTime { + return unixTime{c.t.Unix(), c.t.UnixNano()} +} + +// loadT is invoked by stateify. +func (c *cubicState) loadT(unix unixTime) { + c.t = time.Unix(unix.second, unix.nano) +} diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 23422ca5e..1efe9d3fb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -17,6 +17,7 @@ package tcp import ( "fmt" "math" + "strings" "sync" "sync/atomic" "time" @@ -286,7 +287,7 @@ type endpoint struct { // cc stores the name of the Congestion Control algorithm to use for // this endpoint. - cc CongestionControlOption + cc tcpip.CongestionControlOption // The following are used when a "packet too big" control packet is // received. They are protected by sndBufMu. They are used to @@ -394,7 +395,7 @@ func newEndpoint(stack *stack.Stack, netProto tcpip.NetworkProtocolNumber, waite e.rcvBufSize = rs.Default } - var cs CongestionControlOption + var cs tcpip.CongestionControlOption if err := stack.TransportProtocolOption(ProtocolNumber, &cs); err == nil { e.cc = cs } @@ -898,6 +899,40 @@ func (e *endpoint) SetSockOpt(opt interface{}) *tcpip.Error { e.mu.Unlock() return nil + case tcpip.CongestionControlOption: + // Query the available cc algorithms in the stack and + // validate that the specified algorithm is actually + // supported in the stack. + var avail tcpip.AvailableCongestionControlOption + if err := e.stack.TransportProtocolOption(ProtocolNumber, &avail); err != nil { + return err + } + availCC := strings.Split(string(avail), " ") + for _, cc := range availCC { + if v == tcpip.CongestionControlOption(cc) { + // Acquire the work mutex as we may need to + // reinitialize the congestion control state. + e.mu.Lock() + state := e.state + e.cc = v + e.mu.Unlock() + switch state { + case StateEstablished: + e.workMu.Lock() + e.mu.Lock() + if e.state == state { + e.snd.cc = e.snd.initCongestionControl(e.cc) + } + e.mu.Unlock() + e.workMu.Unlock() + } + return nil + } + } + + // Linux returns ENOENT when an invalid congestion + // control algorithm is specified. + return tcpip.ErrNoSuchFile default: return nil } @@ -1067,6 +1102,12 @@ func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { } return nil + case *tcpip.CongestionControlOption: + e.mu.Lock() + *o = e.cc + e.mu.Unlock() + return nil + default: return tcpip.ErrUnknownProtocolOption } diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index b31bcccfa..59f4009a1 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -79,13 +79,6 @@ const ( ccCubic = "cubic" ) -// CongestionControlOption sets the current congestion control algorithm. -type CongestionControlOption string - -// AvailableCongestionControlOption returns the supported congestion control -// algorithms. -type AvailableCongestionControlOption string - type protocol struct { mu sync.Mutex sackEnabled bool @@ -93,7 +86,6 @@ type protocol struct { recvBufferSize ReceiveBufferSizeOption congestionControl string availableCongestionControl []string - allowedCongestionControl []string } // Number returns the tcp protocol number. @@ -188,7 +180,7 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { p.mu.Unlock() return nil - case CongestionControlOption: + case tcpip.CongestionControlOption: for _, c := range p.availableCongestionControl { if string(v) == c { p.mu.Lock() @@ -197,7 +189,9 @@ func (p *protocol) SetOption(option interface{}) *tcpip.Error { return nil } } - return tcpip.ErrInvalidOptionValue + // linux returns ENOENT when an invalid congestion control + // is specified. + return tcpip.ErrNoSuchFile default: return tcpip.ErrUnknownProtocolOption } @@ -223,14 +217,14 @@ func (p *protocol) Option(option interface{}) *tcpip.Error { *v = p.recvBufferSize p.mu.Unlock() return nil - case *CongestionControlOption: + case *tcpip.CongestionControlOption: p.mu.Lock() - *v = CongestionControlOption(p.congestionControl) + *v = tcpip.CongestionControlOption(p.congestionControl) p.mu.Unlock() return nil - case *AvailableCongestionControlOption: + case *tcpip.AvailableCongestionControlOption: p.mu.Lock() - *v = AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) + *v = tcpip.AvailableCongestionControlOption(strings.Join(p.availableCongestionControl, " ")) p.mu.Unlock() return nil default: diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index b236d7af2..daa3e8341 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -194,8 +194,6 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint s := &sender{ ep: ep, - sndCwnd: InitialCwnd, - sndSsthresh: math.MaxInt64, sndWnd: sndWnd, sndUna: iss + 1, sndNxt: iss + 1, @@ -238,7 +236,13 @@ func newSender(ep *endpoint, iss, irs seqnum.Value, sndWnd seqnum.Size, mss uint return s } -func (s *sender) initCongestionControl(congestionControlName CongestionControlOption) congestionControl { +// initCongestionControl initializes the specified congestion control module and +// returns a handle to it. It also initializes the sndCwnd and sndSsThresh to +// their initial values. +func (s *sender) initCongestionControl(congestionControlName tcpip.CongestionControlOption) congestionControl { + s.sndCwnd = InitialCwnd + s.sndSsthresh = math.MaxInt64 + switch congestionControlName { case ccCubic: return newCubicCC(s) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 779ca8b76..7d8987219 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -3205,13 +3205,14 @@ func TestTCPEndpointProbe(t *testing.T) { } } -func TestSetCongestionControl(t *testing.T) { +func TestStackSetCongestionControl(t *testing.T) { testCases := []struct { - cc tcp.CongestionControlOption - mustPass bool + cc tcpip.CongestionControlOption + err *tcpip.Error }{ - {"reno", true}, - {"cubic", true}, + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, } for _, tc := range testCases { @@ -3221,62 +3222,135 @@ func TestSetCongestionControl(t *testing.T) { s := c.Stack() - if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != nil && tc.mustPass { - t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want not-nil", tcp.ProtocolNumber, tc.cc, err) + var oldCC tcpip.CongestionControlOption + if err := s.TransportProtocolOption(tcp.ProtocolNumber, &oldCC); err != nil { + t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &oldCC, err) } - var cc tcp.CongestionControlOption + if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, tc.cc); err != tc.err { + t.Fatalf("s.SetTransportProtocolOption(%v, %v) = %v, want %v", tcp.ProtocolNumber, tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tc.cc; got != want { + + got, want := cc, oldCC + // If SetTransportProtocolOption is expected to succeed + // then the returned value for congestion control should + // match the one specified in the + // SetTransportProtocolOption call above, else it should + // be what it was before the call to + // SetTransportProtocolOption. + if tc.err == nil { + want = tc.cc + } + if got != want { t.Fatalf("got congestion control: %v, want: %v", got, want) } }) } } -func TestAvailableCongestionControl(t *testing.T) { +func TestStackAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Query permitted congestion control algorithms. - var aCC tcp.AvailableCongestionControlOption + var aCC tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &aCC); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &aCC, err) } - if got, want := aCC, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := aCC, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) } } -func TestSetAvailableCongestionControl(t *testing.T) { +func TestStackSetAvailableCongestionControl(t *testing.T) { c := context.New(t, 1500) defer c.Cleanup() s := c.Stack() // Setting AvailableCongestionControlOption should fail. - aCC := tcp.AvailableCongestionControlOption("xyz") + aCC := tcpip.AvailableCongestionControlOption("xyz") if err := s.SetTransportProtocolOption(tcp.ProtocolNumber, &aCC); err == nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = nil, want non-nil", tcp.ProtocolNumber, &aCC) } // Verify that we still get the expected list of congestion control options. - var cc tcp.AvailableCongestionControlOption + var cc tcpip.AvailableCongestionControlOption if err := s.TransportProtocolOption(tcp.ProtocolNumber, &cc); err != nil { t.Fatalf("s.TransportProtocolOption(%v, %v) = %v", tcp.ProtocolNumber, &cc, err) } - if got, want := cc, tcp.AvailableCongestionControlOption("reno cubic"); got != want { - t.Fatalf("got tcp.AvailableCongestionControlOption: %v, want: %v", got, want) + if got, want := cc, tcpip.AvailableCongestionControlOption("reno cubic"); got != want { + t.Fatalf("got tcpip.AvailableCongestionControlOption: %v, want: %v", got, want) + } +} + +func TestEndpointSetCongestionControl(t *testing.T) { + testCases := []struct { + cc tcpip.CongestionControlOption + err *tcpip.Error + }{ + {"reno", nil}, + {"cubic", nil}, + {"blahblah", tcpip.ErrNoSuchFile}, + } + + for _, connected := range []bool{false, true} { + for _, tc := range testCases { + t.Run(fmt.Sprintf("SetSockOpt(.., %v) w/ connected = %v", tc.cc, connected), func(t *testing.T) { + c := context.New(t, 1500) + defer c.Cleanup() + + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.Stack().NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + t.Fatalf("NewEndpoint failed: %v", err) + } + + var oldCC tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&oldCC); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &oldCC, err) + } + + if connected { + c.Connect(789 /* iss */, 32768 /* rcvWnd */, nil) + } + + if err := c.EP.SetSockOpt(tc.cc); err != tc.err { + t.Fatalf("c.EP.SetSockOpt(%v) = %v, want %v", tc.cc, err, tc.err) + } + + var cc tcpip.CongestionControlOption + if err := c.EP.GetSockOpt(&cc); err != nil { + t.Fatalf("c.EP.SockOpt(%v) = %v", &cc, err) + } + + got, want := cc, oldCC + // If SetSockOpt is expected to succeed then the + // returned value for congestion control should match + // the one specified in the SetSockOpt above, else it + // should be what it was before the call to SetSockOpt. + if tc.err == nil { + want = tc.cc + } + if got != want { + t.Fatalf("got congestion control: %v, want: %v", got, want) + } + }) + } } } func enableCUBIC(t *testing.T, c *context.Context) { t.Helper() - opt := tcp.CongestionControlOption("cubic") + opt := tcpip.CongestionControlOption("cubic") if err := c.Stack().SetTransportProtocolOption(tcp.ProtocolNumber, opt); err != nil { t.Fatalf("c.s.SetTransportProtocolOption(tcp.ProtocolNumber, %v = %v", opt, err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 69a43b6f4..a4d89e24d 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -520,35 +520,21 @@ func (c *Context) CreateConnected(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf c.CreateConnectedWithRawOptions(iss, rcvWnd, epRcvBuf, nil) } -// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends -// the specified option bytes as the Option field in the initial SYN packet. +// Connect performs the 3-way handshake for c.EP with the provided Initial +// Sequence Number (iss) and receive window(rcvWnd) and any options if +// specified. // // It also sets the receive buffer for the endpoint to the specified // value in epRcvBuf. -func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { - // Create TCP endpoint. - var err *tcpip.Error - c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) - if err != nil { - c.t.Fatalf("NewEndpoint failed: %v", err) - } - if got, want := tcp.EndpointState(c.EP.State()), tcp.StateInitial; got != want { - c.t.Errorf("Unexpected endpoint state: want %v, got %v", want, got) - } - - if epRcvBuf != nil { - if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { - c.t.Fatalf("SetSockOpt failed failed: %v", err) - } - } - +// +// PreCondition: c.EP must already be created. +func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) { // Start connection attempt. waitEntry, notifyCh := waiter.NewChannelEntry(nil) c.WQ.EventRegister(&waitEntry, waiter.EventOut) defer c.WQ.EventUnregister(&waitEntry) - err = c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}) - if err != tcpip.ErrConnectStarted { + if err := c.EP.Connect(tcpip.FullAddress{Addr: TestAddr, Port: TestPort}); err != tcpip.ErrConnectStarted { c.t.Fatalf("Unexpected return value from Connect: %v", err) } @@ -590,8 +576,7 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. // Wait for connection to be established. select { case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -604,6 +589,27 @@ func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum. c.Port = tcpHdr.SourcePort() } +// CreateConnectedWithRawOptions creates a connected TCP endpoint and sends +// the specified option bytes as the Option field in the initial SYN packet. +// +// It also sets the receive buffer for the endpoint to the specified +// value in epRcvBuf. +func (c *Context) CreateConnectedWithRawOptions(iss seqnum.Value, rcvWnd seqnum.Size, epRcvBuf *tcpip.ReceiveBufferSizeOption, options []byte) { + // Create TCP endpoint. + var err *tcpip.Error + c.EP, err = c.s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, &c.WQ) + if err != nil { + c.t.Fatalf("NewEndpoint failed: %v", err) + } + + if epRcvBuf != nil { + if err := c.EP.SetSockOpt(*epRcvBuf); err != nil { + c.t.Fatalf("SetSockOpt failed failed: %v", err) + } + } + c.Connect(iss, rcvWnd, options) +} + // RawEndpoint is just a small wrapper around a TCP endpoint's state to make // sending data and ACK packets easy while being able to manipulate the sequence // numbers and timestamp values as needed. diff --git a/test/syscalls/linux/socket_ip_tcp_generic.cc b/test/syscalls/linux/socket_ip_tcp_generic.cc index 5b198f49d..0b76280a7 100644 --- a/test/syscalls/linux/socket_ip_tcp_generic.cc +++ b/test/syscalls/linux/socket_ip_tcp_generic.cc @@ -592,5 +592,109 @@ TEST_P(TCPSocketPairTest, MsgTruncMsgPeek) { EXPECT_EQ(0, memcmp(received_data2, sent_data, sizeof(sent_data))); } +TEST_P(TCPSocketPairTest, SetCongestionControlSucceedsForSupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + // Netstack only supports reno & cubic so we only test these two values here. + { + const char kSetCC[kTcpCaNameMax] = "reno"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } + { + const char kSetCC[kTcpCaNameMax] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +TEST_P(TCPSocketPairTest, SetGetTCPCongestionShortReadBuffer) { + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + { + // Verify that getsockopt/setsockopt work with buffers smaller than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[sizeof(kSetCC)]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); + } +} + +TEST_P(TCPSocketPairTest, SetGetTCPCongestionLargeReadBuffer) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + { + // Verify that getsockopt works with buffers larger than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax + 5]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // Linux copies the minimum of kTcpCaNameMax or the length of the passed in + // buffer and sets optlen to the number of bytes actually copied + // irrespective of the actual length of the congestion control name. + EXPECT_EQ(kTcpCaNameMax, optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +TEST_P(TCPSocketPairTest, SetCongestionControlFailsForUnsupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + auto sockets = ASSERT_NO_ERRNO_AND_VALUE(NewSocketPair()); + char old_cc[kTcpCaNameMax]; + socklen_t optlen; + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &old_cc, &optlen), + SyscallSucceedsWithValue(0)); + + const char kSetCC[] = "invalid_ca_cc"; + ASSERT_THAT(setsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &kSetCC, strlen(kSetCC)), + SyscallFailsWithErrno(ENOENT)); + + char got_cc[kTcpCaNameMax]; + ASSERT_THAT(getsockopt(sockets->first_fd(), IPPROTO_TCP, TCP_CONGESTION, + &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(old_cc))); +} + } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/tcp_socket.cc b/test/syscalls/linux/tcp_socket.cc index e3f9f9f9d..e95b644ac 100644 --- a/test/syscalls/linux/tcp_socket.cc +++ b/test/syscalls/linux/tcp_socket.cc @@ -751,6 +751,133 @@ TEST_P(SimpleTcpSocketTest, NonBlockingConnectRefused) { EXPECT_THAT(close(s.release()), SyscallSucceeds()); } +// Test that setting a supported congestion control algorithm succeeds for an +// unconnected TCP socket +TEST_P(SimpleTcpSocketTest, SetCongestionControlSucceedsForSupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + const char kSetCC[kTcpCaNameMax] = "reno"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); + } + { + const char kSetCC[kTcpCaNameMax] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax]; + memset(got_cc, '1', sizeof(got_cc)); + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kTcpCaNameMax))); + } +} + +// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is +// consistent between linux and gvisor when the passed in buffer is smaller than +// kTcpCaNameMax. +TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionShortReadBuffer) { + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + // Verify that getsockopt/setsockopt work with buffers smaller than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[sizeof(kSetCC)]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + EXPECT_EQ(sizeof(got_cc), optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(got_cc))); + } +} + +// This test verifies that a getsockopt(...TCP_CONGESTION) behaviour is +// consistent between linux and gvisor when the passed in buffer is larger than +// kTcpCaNameMax. +TEST_P(SimpleTcpSocketTest, SetGetTCPCongestionLargeReadBuffer) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + { + // Verify that getsockopt works with buffers larger than + // kTcpCaNameMax. + const char kSetCC[] = "cubic"; + ASSERT_THAT(setsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &kSetCC, + strlen(kSetCC)), + SyscallSucceedsWithValue(0)); + + char got_cc[kTcpCaNameMax + 5]; + socklen_t optlen = sizeof(got_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // Linux copies the minimum of kTcpCaNameMax or the length of the passed in + // buffer and sets optlen to the number of bytes actually copied + // irrespective of the actual length of the congestion control name. + EXPECT_EQ(kTcpCaNameMax, optlen); + EXPECT_EQ(0, memcmp(got_cc, kSetCC, sizeof(kSetCC))); + } +} + +// Test that setting an unsupported congestion control algorithm fails for an +// unconnected TCP socket. +TEST_P(SimpleTcpSocketTest, SetCongestionControlFailsForUnsupported) { + // This is Linux's net/tcp.h TCP_CA_NAME_MAX. + const int kTcpCaNameMax = 16; + + FileDescriptor s = + ASSERT_NO_ERRNO_AND_VALUE(Socket(GetParam(), SOCK_STREAM, IPPROTO_TCP)); + char old_cc[kTcpCaNameMax]; + socklen_t optlen = sizeof(old_cc); + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &old_cc, &optlen), + SyscallSucceedsWithValue(0)); + + const char kSetCC[] = "invalid_ca_kSetCC"; + ASSERT_THAT( + setsockopt(s.get(), SOL_TCP, TCP_CONGESTION, &kSetCC, strlen(kSetCC)), + SyscallFailsWithErrno(ENOENT)); + + char got_cc[kTcpCaNameMax]; + ASSERT_THAT( + getsockopt(s.get(), IPPROTO_TCP, TCP_CONGESTION, &got_cc, &optlen), + SyscallSucceedsWithValue(0)); + // We ignore optlen here as the linux kernel sets optlen to the lower of the + // size of the buffer passed in or kTcpCaNameMax and not the length of the + // congestion control algorithm's actual name. + EXPECT_EQ(0, memcmp(got_cc, old_cc, sizeof(kTcpCaNameMax))); +} + INSTANTIATE_TEST_SUITE_P(AllInetTests, SimpleTcpSocketTest, ::testing::Values(AF_INET, AF_INET6)); -- cgit v1.2.3