summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/epsocket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/epsocket')
-rw-r--r-- pkg/sentry/socket/epsocket/BUILD | 1
-rw-r--r-- pkg/sentry/socket/epsocket/epsocket.go | 108
-rw-r--r-- pkg/sentry/socket/epsocket/provider.go | 35
-rw-r--r-- pkg/sentry/socket/epsocket/stack.go | 5
4 files changed, 113 insertions, 36 deletions
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD
index 45bb24a3f..1f014f399 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/epsocket/BUILD
@@ -28,7 +28,6 @@ go_library(
"//pkg/sentry/inet",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
- "//pkg/sentry/kernel/kdefs",
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
index 8e65e1b3f..e57aed927 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -40,7 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
- "gvisor.dev/gvisor/pkg/sentry/kernel/kdefs"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
@@ -208,6 +207,10 @@ type commonEndpoint interface {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
+
+ // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
+ // transport.Endpoint.GetSockOptInt.
+ GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error)
}
// SocketOperations encapsulates all the state needed to represent a network stack
@@ -250,6 +253,10 @@ type SocketOperations struct {
// timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
+
+ // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket
+ // level, because it takes into account data from readView.
+ sockOptInq bool
}
// New creates a new endpoint socket.
@@ -286,14 +293,14 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
// GetAddress reads an sockaddr struct from the given address and converts it
// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
// addresses.
+//
+// The address family stored in the first two bytes of addr must equal
+// sfamily. When strict is false, AF_UNSPEC is accepted as well; when strict
+// is true it is rejected like any other mismatching family. It returns
+// ErrInvalidArgument if addr is too short to hold a family and
+// ErrAddressFamilyNotSupported if the family check fails.
-func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
+func GetAddress(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) {
+ // A mismatching family is tolerated only for AF_UNSPEC, and only when the
+ // caller did not ask for strict matching.
+ if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) {
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@@ -318,7 +325,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET:
var a linux.SockAddrInet
if len(addr) < sockAddrInetSize {
- return tcpip.FullAddress{}, syserr.ErrBadAddress
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
@@ -331,7 +338,7 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
case linux.AF_INET6:
var a linux.SockAddrInet6
if len(addr) < sockAddrInet6Size {
- return tcpip.FullAddress{}, syserr.ErrBadAddress
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
}
binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
@@ -344,6 +351,9 @@ func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
}
return out, nil
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, nil
+
default:
return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
}
@@ -466,7 +476,7 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Connect implements the linux syscall connect(2) for sockets backed by
// tpcip.Endpoint.
func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr)
+ addr, err := GetAddress(s.family, sockaddr, false /* strict */)
if err != nil {
return err
}
@@ -499,7 +509,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, err := GetAddress(s.family, sockaddr)
+ addr, err := GetAddress(s.family, sockaddr, true /* strict */)
if err != nil {
return err
}
@@ -537,7 +547,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *wait
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
-func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, interface{}, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
ep, wq, terr := s.Endpoint.Accept()
if terr != nil {
@@ -575,10 +585,9 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
}
}
- fdFlags := kernel.FDFlags{
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
- }
- fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+ })
t.Kernel().RecordSocket(ns)
@@ -633,6 +642,18 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (
}
return val, nil
}
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return val, nil
+ }
return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
}
@@ -866,6 +887,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(v), nil
+ case linux.TCP_MAXSEG:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.MaxSegOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
case linux.TCP_KEEPIDLE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1035,6 +1068,15 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
return nil
}
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
@@ -1218,6 +1260,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v)))
+ case linux.TCP_MAXSEG:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v)))
+
case linux.TCP_KEEPIDLE:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -1246,6 +1296,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
return syserr.TranslateNetstackError(err)
}
return nil
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1471,7 +1522,6 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) {
linux.TCP_FASTOPEN_CONNECT,
linux.TCP_FASTOPEN_KEY,
linux.TCP_FASTOPEN_NO_COOKIE,
- linux.TCP_INQ,
linux.TCP_KEEPCNT,
linux.TCP_KEEPIDLE,
linux.TCP_KEEPINTVL,
@@ -1726,6 +1776,18 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
return 0, err
}
+// fillCmsgInq adds the TCP_INQ value to cmsg when the option is enabled on
+// this socket. The reported value is the number of bytes held in s.readView
+// plus the endpoint's receive queue size. If the option is disabled, or the
+// queue size cannot be read, cmsg is left unchanged.
+//
+// NOTE(review): this reads s.sockOptInq and s.readView without taking
+// s.readMu, and the MSG_TRUNC path in nonBlockingRead calls it after
+// releasing that mutex -- confirm the intended locking contract.
+func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) {
+	if !s.sockOptInq {
+		return
+	}
+	if queued, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption); terr == nil {
+		pending := len(s.readView) + queued
+		cmsg.IP.HasInq = true
+		cmsg.IP.Inq = int32(pending)
+	}
+}
+
// nonBlockingRead issues a non-blocking read.
//
// TODO(b/78348848): Support timestamps for stream sockets.
@@ -1745,7 +1807,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
s.readMu.Unlock()
- return n, 0, nil, 0, socket.ControlMessages{}, err
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, 0, nil, 0, cmsg, err
}
s.readMu.Lock()
@@ -1758,8 +1822,8 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
if !isPacket && peek && trunc {
// MSG_TRUNC with MSG_PEEK on a TCP socket returns the
// amount that could be read.
- var rql tcpip.ReceiveQueueSizeOption
- if err := s.Endpoint.GetSockOpt(&rql); err != nil {
+ rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
}
available := len(s.readView) + int(rql)
@@ -1827,7 +1891,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
n = msgLen
}
- return n, flags, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
}
func (s *SocketOperations) controlMessages() socket.ControlMessages {
@@ -1924,7 +1990,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
var addr *tcpip.FullAddress
if len(to) > 0 {
- addrBuf, err := GetAddress(s.family, to)
+ addrBuf, err := GetAddress(s.family, to, true /* strict */)
if err != nil {
return 0, err
}
@@ -1993,7 +2059,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
// Ioctl implements fs.FileOperations.Ioctl.
-func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
// SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
@@ -2065,9 +2131,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCINQ:
- var v tcpip.ReceiveQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 6d2b5d038..421f93dc4 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -40,42 +40,49 @@ type provider struct {
}
// getTransportProtocol figures out transport protocol. Currently only TCP,
-// UDP, and ICMP are supported.
-func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+// UDP, and ICMP are supported. The bool return value is true when this socket
+// is associated with a transport protocol. This is only false for SOCK_RAW,
+// IPPROTO_RAW sockets.
+func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, bool, *syserr.Error) {
switch stype {
case linux.SOCK_STREAM:
if protocol != 0 && protocol != syscall.IPPROTO_TCP {
- return 0, syserr.ErrInvalidArgument
+ return 0, true, syserr.ErrInvalidArgument
}
- return tcp.ProtocolNumber, nil
+ return tcp.ProtocolNumber, true, nil
case linux.SOCK_DGRAM:
switch protocol {
case 0, syscall.IPPROTO_UDP:
- return udp.ProtocolNumber, nil
+ return udp.ProtocolNumber, true, nil
case syscall.IPPROTO_ICMP:
- return header.ICMPv4ProtocolNumber, nil
+ return header.ICMPv4ProtocolNumber, true, nil
case syscall.IPPROTO_ICMPV6:
- return header.ICMPv6ProtocolNumber, nil
+ return header.ICMPv6ProtocolNumber, true, nil
}
case linux.SOCK_RAW:
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrPermissionDenied
}
switch protocol {
case syscall.IPPROTO_ICMP:
- return header.ICMPv4ProtocolNumber, nil
+ return header.ICMPv4ProtocolNumber, true, nil
case syscall.IPPROTO_UDP:
- return header.UDPProtocolNumber, nil
+ return header.UDPProtocolNumber, true, nil
case syscall.IPPROTO_TCP:
- return header.TCPProtocolNumber, nil
+ return header.TCPProtocolNumber, true, nil
+ // IPPROTO_RAW signifies that the raw socket isn't assigned to
+ // a transport protocol. Users will be able to write packets'
+ // IP headers and won't receive anything.
+ case syscall.IPPROTO_RAW:
+ return tcpip.TransportProtocolNumber(0), false, nil
}
}
- return 0, syserr.ErrProtocolNotSupported
+ return 0, true, syserr.ErrProtocolNotSupported
}
// Socket creates a new socket object for the AF_INET or AF_INET6 family.
@@ -93,7 +100,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
}
// Figure out the transport protocol.
- transProto, err := getTransportProtocol(t, stype, protocol)
+ transProto, associated, err := getTransportProtocol(t, stype, protocol)
if err != nil {
return nil, err
}
@@ -103,7 +110,7 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
var e *tcpip.Error
wq := &waiter.Queue{}
if stype == linux.SOCK_RAW {
- ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq)
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
} else {
ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
}
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
index 1627a4f68..7eef19f74 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -138,3 +138,8 @@ func (s *Stack) TCPSACKEnabled() (bool, error) {
func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
}
+
+// Statistics implements inet.Stack.Statistics.
+//
+// Both arguments are ignored: the call unconditionally fails with the error
+// obtained from syserr.ErrEndpointOperation.
+func (s *Stack) Statistics(stat interface{}, arg string) error {
+	unsupported := syserr.ErrEndpointOperation
+	return unsupported.ToError()
+}