summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/netlink/socket.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/netlink/socket.go')
-rw-r--r--pkg/sentry/socket/netlink/socket.go72
1 files changed, 43 insertions, 29 deletions
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 2ca02567d..81f34c5a2 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -58,6 +58,8 @@ var errNoFilter = syserr.New("no filter attached", linux.ENOENT)
// netlinkSocketDevice is the netlink socket virtual device.
var netlinkSocketDevice = device.NewAnonDevice()
+// LINT.IfChange
+
// Socket is the base socket type for netlink sockets.
//
// This implementation only supports userspace sending and receiving messages
@@ -74,6 +76,14 @@ type Socket struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
socket.SendReceiveTimeout
// ports provides netlink port allocation.
@@ -140,17 +150,19 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke
}
return &Socket{
- ports: t.Kernel().NetlinkPorts(),
- protocol: protocol,
- skType: skType,
- ep: ep,
- connection: connection,
- sendBufferSize: defaultSendBufferSize,
+ socketOpsCommon: socketOpsCommon{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ skType: skType,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ },
}, nil
}
// Release implements fs.FileOperations.Release.
-func (s *Socket) Release() {
+func (s *socketOpsCommon) Release() {
s.connection.Release()
s.ep.Close()
@@ -160,7 +172,7 @@ func (s *Socket) Release() {
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
// ep holds messages to be read and thus handles EventIn readiness.
ready := s.ep.Readiness(mask)
@@ -174,18 +186,18 @@ func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *Socket) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.ep.EventRegister(e, mask)
// Writable readiness never changes, so no registration is needed.
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *Socket) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
// Passcred implements transport.Credentialer.Passcred.
-func (s *Socket) Passcred() bool {
+func (s *socketOpsCommon) Passcred() bool {
s.mu.Lock()
passcred := s.passcred
s.mu.Unlock()
@@ -193,7 +205,7 @@ func (s *Socket) Passcred() bool {
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
-func (s *Socket) ConnectedPasscred() bool {
+func (s *socketOpsCommon) ConnectedPasscred() bool {
// This socket is connected to the kernel, which doesn't need creds.
//
// This is arbitrary, as ConnectedPasscred on this type has no callers.
@@ -227,7 +239,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
// port of 0 defaults to the ThreadGroup ID.
//
// Preconditions: mu is held.
-func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
+func (s *socketOpsCommon) bindPort(t *kernel.Task, port int32) *syserr.Error {
if s.bound {
// Re-binding is only allowed if the port doesn't change.
if port != s.portID {
@@ -251,7 +263,7 @@ func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
}
// Bind implements socket.Socket.Bind.
-func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
a, err := ExtractSockAddr(sockaddr)
if err != nil {
return err
@@ -269,7 +281,7 @@ func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Connect implements socket.Socket.Connect.
-func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
a, err := ExtractSockAddr(sockaddr)
if err != nil {
return err
@@ -300,25 +312,25 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr
}
// Accept implements socket.Socket.Accept.
-func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Netlink sockets never support accept.
return 0, nil, 0, syserr.ErrNotSupported
}
// Listen implements socket.Socket.Listen.
-func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// Netlink sockets never support listen.
return syserr.ErrNotSupported
}
// Shutdown implements socket.Socket.Shutdown.
-func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
// Netlink sockets never support shutdown.
return syserr.ErrNotSupported
}
// GetSockOpt implements socket.Socket.GetSockOpt.
-func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -369,7 +381,7 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.
}
// SetSockOpt implements socket.Socket.SetSockOpt.
-func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
switch level {
case linux.SOL_SOCKET:
switch name {
@@ -466,7 +478,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
}
// GetSockName implements socket.Socket.GetSockName.
-func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -478,7 +490,7 @@ func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er
}
// GetPeerName implements socket.Socket.GetPeerName.
-func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
sa := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
// TODO(b/68878065): Support non-kernel peers. For now the peer
@@ -489,7 +501,7 @@ func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
from := &linux.SockAddrNetlink{
Family: linux.AF_NETLINK,
PortID: 0,
@@ -590,7 +602,7 @@ func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
var kernelCreds = &kernelSCM{}
// sendResponse sends the response messages in ms back to userspace.
-func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
+func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
// Linux combines multiple netlink messages into a single datagram.
bufs := make([][]byte, 0, len(ms.Messages))
for _, m := range ms.Messages {
@@ -666,7 +678,7 @@ func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
// processMessages handles each message in buf, passing it to the protocol
// handler for final handling.
-func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
+func (s *socketOpsCommon) processMessages(ctx context.Context, buf []byte) *syserr.Error {
for len(buf) > 0 {
msg, rest, ok := ParseMessage(buf)
if !ok {
@@ -698,7 +710,7 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error
}
// sendMsg is the core of message send, used for SendMsg and Write.
-func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
dstPort := int32(0)
if len(to) != 0 {
@@ -745,7 +757,7 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte,
}
// SendMsg implements socket.Socket.SendMsg.
-func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
return s.sendMsg(t, src, to, flags, controlMessages)
}
@@ -756,11 +768,13 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence,
}
// State implements socket.Socket.State.
-func (s *Socket) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
return s.ep.State()
}
// Type implements socket.Socket.Type.
-func (s *Socket) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
}
+
+// LINT.ThenChange(./socket_vfs2.go)