summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
authorgVisor bot <gvisor-bot@google.com>2019-06-10 22:42:41 +0000
committergVisor bot <gvisor-bot@google.com>2019-06-10 22:42:41 +0000
commit8390a8227b571e82c42e3e90aa28a7b86f7e3f9b (patch)
tree8bfad5169182b7ba1c6ed5f3df0279729cc200b0 /pkg/sentry/socket
parent4f56f1bf2248bb17da8b269b4191218d85ce6587 (diff)
parenta00157cc0e216a9829f2659ce35c856a22aa5ba2 (diff)
Merge a00157cc (automated)
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go66
-rwxr-xr-xpkg/sentry/socket/epsocket/epsocket_state_autogen.go2
-rw-r--r--pkg/sentry/socket/epsocket/provider.go9
-rw-r--r--pkg/sentry/socket/hostinet/socket.go61
-rwxr-xr-xpkg/sentry/socket/netlink/netlink_state_autogen.go2
-rw-r--r--pkg/sentry/socket/netlink/provider.go9
-rw-r--r--pkg/sentry/socket/netlink/socket.go17
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go31
-rw-r--r--pkg/sentry/socket/socket.go21
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go29
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go20
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go26
-rw-r--r--pkg/sentry/socket/unix/unix.go74
-rwxr-xr-xpkg/sentry/socket/unix/unix_state_autogen.go4
14 files changed, 264 insertions, 107 deletions
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index de4b963da..f67451179 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -44,7 +44,6 @@ import (
ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
@@ -52,6 +51,7 @@ import (
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
"gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
"gvisor.googlesource.com/gvisor/pkg/waiter"
)
@@ -227,7 +227,8 @@ type SocketOperations struct {
family int
Endpoint tcpip.Endpoint
- skType transport.SockType
+ skType linux.SockType
+ protocol int
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
@@ -252,8 +253,8 @@ type SocketOperations struct {
}
// New creates a new endpoint socket.
-func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
- if skType == transport.SockStream {
+func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
+ if skType == linux.SOCK_STREAM {
if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -266,6 +267,7 @@ func New(t *kernel.Task, family int, skType transport.SockType, queue *waiter.Qu
family: family,
Endpoint: endpoint,
skType: skType,
+ protocol: protocol,
}), nil
}
@@ -550,7 +552,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns, err := New(t, s.family, s.skType, wq, ep)
+ ns, err := New(t, s.family, s.skType, s.protocol, wq, ep)
if err != nil {
return 0, nil, 0, err
}
@@ -578,7 +580,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
- t.Kernel().RecordSocket(ns, s.family)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, syserr.FromError(e)
}
@@ -637,7 +639,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -663,7 +665,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
}
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType transport.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_TYPE:
@@ -2281,3 +2283,51 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
}
return rv
}
+
+// State implements socket.Socket.State. State translates the internal state
+// returned by netstack to values defined by Linux.
+func (s *SocketOperations) State() uint32 {
+ if s.family != linux.AF_INET && s.family != linux.AF_INET6 {
+ // States not implemented for this socket's family.
+ return 0
+ }
+
+ if !s.isPacketBased() {
+ // TCP socket.
+ switch tcp.EndpointState(s.Endpoint.State()) {
+ case tcp.StateEstablished:
+ return linux.TCP_ESTABLISHED
+ case tcp.StateSynSent:
+ return linux.TCP_SYN_SENT
+ case tcp.StateSynRecv:
+ return linux.TCP_SYN_RECV
+ case tcp.StateFinWait1:
+ return linux.TCP_FIN_WAIT1
+ case tcp.StateFinWait2:
+ return linux.TCP_FIN_WAIT2
+ case tcp.StateTimeWait:
+ return linux.TCP_TIME_WAIT
+ case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError:
+ return linux.TCP_CLOSE
+ case tcp.StateCloseWait:
+ return linux.TCP_CLOSE_WAIT
+ case tcp.StateLastAck:
+ return linux.TCP_LAST_ACK
+ case tcp.StateListen:
+ return linux.TCP_LISTEN
+ case tcp.StateClosing:
+ return linux.TCP_CLOSING
+ default:
+ // Internal or unknown state.
+ return 0
+ }
+ }
+
+ // TODO(b/112063468): Export states for UDP, ICMP, and raw sockets.
+ return 0
+}
+
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.skType, s.protocol
+}
diff --git a/pkg/sentry/socket/epsocket/epsocket_state_autogen.go b/pkg/sentry/socket/epsocket/epsocket_state_autogen.go
index 4b407b796..37e0276fb 100755
--- a/pkg/sentry/socket/epsocket/epsocket_state_autogen.go
+++ b/pkg/sentry/socket/epsocket/epsocket_state_autogen.go
@@ -14,6 +14,7 @@ func (x *SocketOperations) save(m state.Map) {
m.Save("family", &x.family)
m.Save("Endpoint", &x.Endpoint)
m.Save("skType", &x.skType)
+ m.Save("protocol", &x.protocol)
m.Save("readView", &x.readView)
m.Save("readCM", &x.readCM)
m.Save("sender", &x.sender)
@@ -29,6 +30,7 @@ func (x *SocketOperations) load(m state.Map) {
m.Load("family", &x.family)
m.Load("Endpoint", &x.Endpoint)
m.Load("skType", &x.skType)
+ m.Load("protocol", &x.protocol)
m.Load("readView", &x.readView)
m.Load("readCM", &x.readCM)
m.Load("sender", &x.sender)
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index ec930d8d5..516582828 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -23,7 +23,6 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/header"
@@ -42,7 +41,7 @@ type provider struct {
// getTransportProtocol figures out transport protocol. Currently only TCP,
// UDP, and ICMP are supported.
-func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
switch stype {
case linux.SOCK_STREAM:
if protocol != 0 && protocol != syscall.IPPROTO_TCP {
@@ -80,7 +79,7 @@ func getTransportProtocol(ctx context.Context, stype transport.SockType, protoco
}
// Socket creates a new socket object for the AF_INET or AF_INET6 family.
-func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Fail right away if we don't have a stack.
stack := t.NetworkContext()
if stack == nil {
@@ -112,11 +111,11 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int
return nil, syserr.TranslateNetstackError(e)
}
- return New(t, p.family, stype, wq, ep)
+ return New(t, p.family, stype, protocol, wq, ep)
}
// Pair just returns nil sockets (not supported).
-func (*provider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
return nil, nil, nil
}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 41f9693bb..c62c8d8f1 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -19,7 +19,9 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
"gvisor.googlesource.com/gvisor/pkg/fdnotifier"
+ "gvisor.googlesource.com/gvisor/pkg/log"
"gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
@@ -28,7 +30,6 @@ import (
ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
"gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/syserror"
@@ -55,15 +56,22 @@ type socketOperations struct {
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
- family int // Read-only.
- fd int // must be O_NONBLOCK
- queue waiter.Queue
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+ fd int // must be O_NONBLOCK
+ queue waiter.Queue
}
var _ = socket.Socket(&socketOperations{})
-func newSocketFile(ctx context.Context, family int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
- s := &socketOperations{family: family, fd: fd}
+func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
+ s := &socketOperations{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ }
if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
return nil, syserr.FromError(err)
}
@@ -221,7 +229,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
}
- f, err := newSocketFile(t, s.family, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
if err != nil {
syscall.Close(fd)
return 0, nil, 0, err
@@ -232,7 +240,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
}
kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits())
- t.Kernel().RecordSocket(f, s.family)
+ t.Kernel().RecordSocket(f)
return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
}
@@ -519,12 +527,39 @@ func translateIOSyscallError(err error) error {
return err
}
+// State implements socket.Socket.State.
+func (s *socketOperations) State() uint32 {
+ info := linux.TCPInfo{}
+ buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo)
+ if err != nil {
+ if err != syscall.ENOPROTOOPT {
+ log.Warningf("Failed to get TCP socket info from %+v: %v", s, err)
+ }
+ // For non-TCP sockets, silently ignore the failure.
+ return 0
+ }
+ if len(buf) != linux.SizeOfTCPInfo {
+ // Unmarshal below will panic if getsockopt returns a buffer of
+ // unexpected size.
+ log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo)
+ return 0
+ }
+
+ binary.Unmarshal(buf, usermem.ByteOrder, &info)
+ return uint32(info.State)
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
type socketProvider struct {
family int
}
// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Check that we are using the host network stack.
stack := t.NetworkContext()
if stack == nil {
@@ -535,7 +570,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
}
// Only accept TCP and UDP.
- stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ stype := stypeflags & linux.SOCK_TYPE_MASK
switch stype {
case syscall.SOCK_STREAM:
switch protocol {
@@ -558,15 +593,15 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
// Conservatively ignore all flags specified by the application and add
// SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
// to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
- fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
if err != nil {
return nil, syserr.FromError(err)
}
- return newSocketFile(t, p.family, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
+ return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
}
// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
// Not supported by AF_INET/AF_INET6.
return nil, nil, nil
}
diff --git a/pkg/sentry/socket/netlink/netlink_state_autogen.go b/pkg/sentry/socket/netlink/netlink_state_autogen.go
index 59d902798..e51a86b19 100755
--- a/pkg/sentry/socket/netlink/netlink_state_autogen.go
+++ b/pkg/sentry/socket/netlink/netlink_state_autogen.go
@@ -12,6 +12,7 @@ func (x *Socket) save(m state.Map) {
m.Save("SendReceiveTimeout", &x.SendReceiveTimeout)
m.Save("ports", &x.ports)
m.Save("protocol", &x.protocol)
+ m.Save("skType", &x.skType)
m.Save("ep", &x.ep)
m.Save("connection", &x.connection)
m.Save("bound", &x.bound)
@@ -24,6 +25,7 @@ func (x *Socket) load(m state.Map) {
m.Load("SendReceiveTimeout", &x.SendReceiveTimeout)
m.Load("ports", &x.ports)
m.Load("protocol", &x.protocol)
+ m.Load("skType", &x.skType)
m.Load("ep", &x.ep)
m.Load("connection", &x.connection)
m.Load("bound", &x.bound)
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 76cf12fd4..5dc103877 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -22,7 +22,6 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/syserr"
)
@@ -66,10 +65,10 @@ type socketProvider struct {
}
// Socket implements socket.Provider.Socket.
-func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Netlink sockets must be specified as datagram or raw, but they
// behave the same regardless of type.
- if stype != transport.SockDgram && stype != transport.SockRaw {
+ if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW {
return nil, syserr.ErrSocketNotSupported
}
@@ -83,7 +82,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol
return nil, err
}
- s, err := NewSocket(t, p)
+ s, err := NewSocket(t, stype, p)
if err != nil {
return nil, err
}
@@ -94,7 +93,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype transport.SockType, protocol
}
// Pair implements socket.Provider.Pair by returning an error.
-func (*socketProvider) Pair(*kernel.Task, transport.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
// Netlink sockets never supports creating socket pairs.
return nil, nil, syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index afd06ca33..62659784a 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -80,6 +80,10 @@ type Socket struct {
// protocol is the netlink protocol implementation.
protocol Protocol
+ // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
+ // netlink sockets.
+ skType linux.SockType
+
// ep is a datagram unix endpoint used to buffer messages sent from the
// kernel to userspace. RecvMsg reads messages from this endpoint.
ep transport.Endpoint
@@ -105,7 +109,7 @@ type Socket struct {
var _ socket.Socket = (*Socket)(nil)
// NewSocket creates a new Socket.
-func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
+func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
// Datagram endpoint used to buffer kernel -> user messages.
ep := transport.NewConnectionless()
@@ -126,6 +130,7 @@ func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
return &Socket{
ports: t.Kernel().NetlinkPorts(),
protocol: protocol,
+ skType: skType,
ep: ep,
connection: connection,
sendBufferSize: defaultSendBufferSize,
@@ -616,3 +621,13 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence,
n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
return int64(n), err.ToError()
}
+
+// State implements socket.Socket.State.
+func (s *Socket) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
+ return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
+}
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
index 55e0b6665..c22ff1ff0 100644
--- a/pkg/sentry/socket/rpcinet/socket.go
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -32,7 +32,6 @@ import (
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier"
pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
- "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/sentry/unimpl"
"gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
"gvisor.googlesource.com/gvisor/pkg/syserr"
@@ -54,7 +53,10 @@ type socketOperations struct {
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
- family int // Read-only.
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+
fd uint32 // must be O_NONBLOCK
wq *waiter.Queue
rpcConn *conn.RPCConnection
@@ -70,7 +72,7 @@ type socketOperations struct {
var _ = socket.Socket(&socketOperations{})
// New creates a new RPC socket.
-func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) {
+func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) {
id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */)
<-c
@@ -87,6 +89,8 @@ func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, pr
defer dirent.DecRef()
return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{
family: family,
+ stype: skType,
+ protocol: protocol,
wq: &wq,
fd: fd,
rpcConn: stack.rpcConn,
@@ -333,7 +337,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
if err != nil {
return 0, nil, 0, syserr.FromError(err)
}
- t.Kernel().RecordSocket(file, s.family)
+ t.Kernel().RecordSocket(file)
if peerRequested {
return fd, payload.Address.Address, payload.Address.Length, nil
@@ -830,12 +834,23 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
}
+// State implements socket.Socket.State.
+func (s *socketOperations) State() uint32 {
+ // TODO(b/127845868): Define a new rpc to query the socket state.
+ return 0
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
type socketProvider struct {
family int
}
// Socket implements socket.Provider.Socket.
-func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Check that we are using the RPC network stack.
stack := t.NetworkContext()
if stack == nil {
@@ -851,7 +866,7 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
//
// Try to restrict the flags we will accept to minimize backwards
// incompatibility with netstack.
- stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ stype := stypeflags & linux.SOCK_TYPE_MASK
switch stype {
case syscall.SOCK_STREAM:
switch protocol {
@@ -871,11 +886,11 @@ func (p *socketProvider) Socket(t *kernel.Task, stypeflags transport.SockType, p
return nil, nil
}
- return newSocketFile(t, s, p.family, stype, 0)
+ return newSocketFile(t, s, p.family, stype, protocol)
}
// Pair implements socket.Provider.Pair.
-func (p *socketProvider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
// Not supported by AF_INET/AF_INET6.
return nil, nil, nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 9393acd28..d60944b6b 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -116,6 +116,13 @@ type Socket interface {
// SendTimeout gets the current timeout (in ns) for send operations. Zero
// means no timeout, and negative means DONTWAIT.
SendTimeout() int64
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs. The returned state value is protocol-specific.
+ State() uint32
+
+ // Type returns the family, socket type and protocol of the socket.
+ Type() (family int, skType linux.SockType, protocol int)
}
// Provider is the interface implemented by providers of sockets for specific
@@ -126,12 +133,12 @@ type Provider interface {
// If a nil Socket _and_ a nil error is returned, it means that the
// protocol is not supported. A non-nil error should only be returned
// if the protocol is supported, but an error occurs during creation.
- Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error)
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error)
// Pair creates a pair of connected sockets.
//
// See Socket for error information.
- Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
}
// families holds a map of all known address families and their providers.
@@ -145,14 +152,14 @@ func RegisterProvider(family int, provider Provider) {
}
// New creates a new socket with the given family, type and protocol.
-func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
for _, p := range families[family] {
s, err := p.Socket(t, stype, protocol)
if err != nil {
return nil, err
}
if s != nil {
- t.Kernel().RecordSocket(s, family)
+ t.Kernel().RecordSocket(s)
return s, nil
}
}
@@ -162,7 +169,7 @@ func New(t *kernel.Task, family int, stype transport.SockType, protocol int) (*f
// Pair creates a new connected socket pair with the given family, type and
// protocol.
-func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
providers, ok := families[family]
if !ok {
return nil, nil, syserr.ErrAddressFamilyNotSupported
@@ -175,8 +182,8 @@ func Pair(t *kernel.Task, family int, stype transport.SockType, protocol int) (*
}
if s1 != nil && s2 != nil {
k := t.Kernel()
- k.RecordSocket(s1, family)
- k.RecordSocket(s2, family)
+ k.RecordSocket(s1)
+ k.RecordSocket(s2)
return s1, s2, nil
}
}
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 18e492862..db79ac904 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -17,6 +17,7 @@ package transport
import (
"sync"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/waiter"
@@ -44,7 +45,7 @@ type ConnectingEndpoint interface {
// Type returns the socket type, typically either SockStream or
// SockSeqpacket. The connection attempt must be aborted if this
// value doesn't match the ConnectableEndpoint's type.
- Type() SockType
+ Type() linux.SockType
// GetLocalAddress returns the bound path.
GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
@@ -100,7 +101,7 @@ type connectionedEndpoint struct {
// stype is used by connecting sockets to ensure that they are the
// same type. The value is typically either tcpip.SockSeqpacket or
// tcpip.SockStream.
- stype SockType
+ stype linux.SockType
// acceptedChan is per the TCP endpoint implementation. Note that the
// sockets in this channel are _already in the connected state_, and
@@ -111,7 +112,7 @@ type connectionedEndpoint struct {
}
// NewConnectioned creates a new unbound connectionedEndpoint.
-func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint {
+func NewConnectioned(stype linux.SockType, uid UniqueIDProvider) Endpoint {
return &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
@@ -121,7 +122,7 @@ func NewConnectioned(stype SockType, uid UniqueIDProvider) Endpoint {
}
// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
-func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+func NewPair(stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
a := &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
id: uid.UniqueID(),
@@ -138,7 +139,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
- if stype == SockStream {
+ if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
} else {
@@ -162,7 +163,7 @@ func NewPair(stype SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
// NewExternal creates a new externally backed Endpoint. It behaves like a
// socketpair.
-func NewExternal(stype SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
+func NewExternal(stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
return &connectionedEndpoint{
baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
id: uid.UniqueID(),
@@ -177,7 +178,7 @@ func (e *connectionedEndpoint) ID() uint64 {
}
// Type implements ConnectingEndpoint.Type and Endpoint.Type.
-func (e *connectionedEndpoint) Type() SockType {
+func (e *connectionedEndpoint) Type() linux.SockType {
return e.stype
}
@@ -293,7 +294,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
- if e.stype == SockStream {
+ if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
ne.receiver = &queueReceiver{readQueue: writeQueue}
@@ -308,7 +309,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ce ConnectingEndpoint, retur
writeQueue: writeQueue,
}
readQueue.IncRef()
- if e.stype == SockStream {
+ if e.stype == linux.SOCK_STREAM {
returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
} else {
returnConnect(&queueReceiver{readQueue: readQueue}, connected)
@@ -428,7 +429,7 @@ func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syser
func (e *connectionedEndpoint) SendMsg(data [][]byte, c ControlMessages, to BoundEndpoint) (uintptr, *syserr.Error) {
// Stream sockets do not support specifying the endpoint. Seqpacket
// sockets ignore the passed endpoint.
- if e.stype == SockStream && to != nil {
+ if e.stype == linux.SOCK_STREAM && to != nil {
return 0, syserr.ErrNotSupported
}
return e.baseEndpoint.SendMsg(data, c, to)
@@ -458,3 +459,11 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask
return ready
}
+
+// State implements socket.Socket.State.
+func (e *connectionedEndpoint) State() uint32 {
+ if e.Connected() {
+ return linux.SS_CONNECTED
+ }
+ return linux.SS_UNCONNECTED
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 43ff875e4..81ebfba10 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -15,6 +15,7 @@
package transport
import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/waiter"
@@ -118,8 +119,8 @@ func (e *connectionlessEndpoint) SendMsg(data [][]byte, c ControlMessages, to Bo
}
// Type implements Endpoint.Type.
-func (e *connectionlessEndpoint) Type() SockType {
- return SockDgram
+func (e *connectionlessEndpoint) Type() linux.SockType {
+ return linux.SOCK_DGRAM
}
// Connect attempts to connect directly to server.
@@ -194,3 +195,18 @@ func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMa
return ready
}
+
+// State implements socket.Socket.State.
+func (e *connectionlessEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
+ switch {
+ case e.isBound():
+ return linux.SS_UNCONNECTED
+ case e.Connected():
+ return linux.SS_CONNECTING
+ default:
+ return linux.SS_DISCONNECTING
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 37d82bb6b..5c55c529e 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -19,6 +19,7 @@ import (
"sync"
"sync/atomic"
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
"gvisor.googlesource.com/gvisor/pkg/syserr"
"gvisor.googlesource.com/gvisor/pkg/tcpip"
"gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
@@ -28,21 +29,6 @@ import (
// initialLimit is the starting limit for the socket buffers.
const initialLimit = 16 * 1024
-// A SockType is a type (as opposed to family) of sockets. These are enumerated
-// in the syscall package as syscall.SOCK_* constants.
-type SockType int
-
-const (
- // SockStream corresponds to syscall.SOCK_STREAM.
- SockStream SockType = 1
- // SockDgram corresponds to syscall.SOCK_DGRAM.
- SockDgram SockType = 2
- // SockRaw corresponds to syscall.SOCK_RAW.
- SockRaw SockType = 3
- // SockSeqpacket corresponds to syscall.SOCK_SEQPACKET.
- SockSeqpacket SockType = 5
-)
-
// A RightsControlMessage is a control message containing FDs.
type RightsControlMessage interface {
// Clone returns a copy of the RightsControlMessage.
@@ -175,7 +161,7 @@ type Endpoint interface {
// Type return the socket type, typically either SockStream, SockDgram
// or SockSeqpacket.
- Type() SockType
+ Type() linux.SockType
// GetLocalAddress returns the address to which the endpoint is bound.
GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
@@ -191,6 +177,10 @@ type Endpoint interface {
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs.
+ State() uint32
}
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
@@ -625,7 +615,7 @@ type connectedEndpoint struct {
GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
// Type implements Endpoint.Type.
- Type() SockType
+ Type() linux.SockType
}
writeQueue *queue
@@ -649,7 +639,7 @@ func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages,
}
truncate := false
- if e.endpoint.Type() == SockStream {
+ if e.endpoint.Type() == linux.SOCK_STREAM {
// Since stream sockets don't preserve message boundaries, we
// can write only as much of the message as fits in the queue.
truncate = true
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 388cc0d8b..b07e8d67b 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -17,6 +17,7 @@
package unix
import (
+ "fmt"
"strings"
"syscall"
@@ -55,22 +56,22 @@ type SocketOperations struct {
refs.AtomicRefCount
socket.SendReceiveTimeout
- ep transport.Endpoint
- isPacket bool
+ ep transport.Endpoint
+ stype linux.SockType
}
// New creates a new unix socket.
-func New(ctx context.Context, endpoint transport.Endpoint, isPacket bool) *fs.File {
+func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
dirent := socket.NewDirent(ctx, unixSocketDevice)
defer dirent.DecRef()
- return NewWithDirent(ctx, dirent, endpoint, isPacket, fs.FileFlags{Read: true, Write: true})
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true})
}
// NewWithDirent creates a new unix socket using an existing dirent.
-func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, isPacket bool, flags fs.FileFlags) *fs.File {
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
return fs.NewFile(ctx, d, flags, &SocketOperations{
- ep: ep,
- isPacket: isPacket,
+ ep: ep,
+ stype: stype,
})
}
@@ -88,6 +89,18 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
+func (s *SocketOperations) isPacket() bool {
+ switch s.stype {
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ return true
+ case linux.SOCK_STREAM:
+ return false
+ default:
+ // We shouldn't have allowed any other socket types during creation.
+ panic(fmt.Sprintf("Invalid socket type %d", s.stype))
+ }
+}
+
// Endpoint extracts the transport.Endpoint.
func (s *SocketOperations) Endpoint() transport.Endpoint {
return s.ep
@@ -193,7 +206,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- ns := New(t, ep, s.isPacket)
+ ns := New(t, ep, s.stype)
defer ns.DecRef()
if flags&linux.SOCK_NONBLOCK != 0 {
@@ -221,7 +234,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
return 0, nil, 0, syserr.FromError(e)
}
- t.Kernel().RecordSocket(ns, linux.AF_UNIX)
+ t.Kernel().RecordSocket(ns)
return fd, addr, addrLen, nil
}
@@ -487,6 +500,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
waitAll := flags&linux.MSG_WAITALL != 0
+ isPacket := s.isPacket()
// Calculate the number of FDs for which we have space and if we are
// requesting credentials.
@@ -528,8 +542,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msgFlags |= linux.MSG_CTRUNC
}
- if err != nil || dontWait || !waitAll || s.isPacket || n >= dst.NumBytes() {
- if s.isPacket && n < int64(r.MsgSize) {
+ if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
@@ -570,11 +584,11 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || s.isPacket || n >= dst.NumBytes() {
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
if total > 0 {
err = nil
}
- if s.isPacket && n < int64(r.MsgSize) {
+ if isPacket && n < int64(r.MsgSize) {
msgFlags |= linux.MSG_TRUNC
}
return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
@@ -596,11 +610,22 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
}
+// State implements socket.Socket.State.
+func (s *SocketOperations) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+ // Unix domain sockets always have a protocol of 0.
+ return linux.AF_UNIX, s.stype, 0
+}
+
// provider is a unix domain socket provider.
type provider struct{}
// Socket returns a new unix domain socket.
-func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *syserr.Error) {
+func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Check arguments.
if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
return nil, syserr.ErrProtocolNotSupported
@@ -608,43 +633,36 @@ func (*provider) Socket(t *kernel.Task, stype transport.SockType, protocol int)
// Create the endpoint and socket.
var ep transport.Endpoint
- var isPacket bool
switch stype {
case linux.SOCK_DGRAM:
- isPacket = true
ep = transport.NewConnectionless()
- case linux.SOCK_SEQPACKET:
- isPacket = true
- fallthrough
- case linux.SOCK_STREAM:
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
ep = transport.NewConnectioned(stype, t.Kernel())
default:
return nil, syserr.ErrInvalidArgument
}
- return New(t, ep, isPacket), nil
+ return New(t, ep, stype), nil
}
// Pair creates a new pair of AF_UNIX connected sockets.
-func (*provider) Pair(t *kernel.Task, stype transport.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
// Check arguments.
if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
return nil, nil, syserr.ErrProtocolNotSupported
}
- var isPacket bool
switch stype {
- case linux.SOCK_STREAM:
- case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
- isPacket = true
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ // Ok
default:
return nil, nil, syserr.ErrInvalidArgument
}
// Create the endpoints and sockets.
ep1, ep2 := transport.NewPair(stype, t.Kernel())
- s1 := New(t, ep1, isPacket)
- s2 := New(t, ep2, isPacket)
+ s1 := New(t, ep1, stype)
+ s2 := New(t, ep2, stype)
return s1, s2, nil
}
diff --git a/pkg/sentry/socket/unix/unix_state_autogen.go b/pkg/sentry/socket/unix/unix_state_autogen.go
index 6f8d24b44..605f6fb59 100755
--- a/pkg/sentry/socket/unix/unix_state_autogen.go
+++ b/pkg/sentry/socket/unix/unix_state_autogen.go
@@ -12,7 +12,7 @@ func (x *SocketOperations) save(m state.Map) {
m.Save("AtomicRefCount", &x.AtomicRefCount)
m.Save("SendReceiveTimeout", &x.SendReceiveTimeout)
m.Save("ep", &x.ep)
- m.Save("isPacket", &x.isPacket)
+ m.Save("stype", &x.stype)
}
func (x *SocketOperations) afterLoad() {}
@@ -20,7 +20,7 @@ func (x *SocketOperations) load(m state.Map) {
m.Load("AtomicRefCount", &x.AtomicRefCount)
m.Load("SendReceiveTimeout", &x.SendReceiveTimeout)
m.Load("ep", &x.ep)
- m.Load("isPacket", &x.isPacket)
+ m.Load("stype", &x.stype)
}
func init() {