summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket/epsocket/provider.go
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket/epsocket/provider.go')
-rw-r--r--pkg/sentry/socket/epsocket/provider.go28
1 files changed, 24 insertions, 4 deletions
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
index 0184d8e3e..0d9c2df24 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -18,8 +18,10 @@ import (
"syscall"
"gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
"gvisor.googlesource.com/gvisor/pkg/sentry/fs"
"gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket"
"gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.googlesource.com/gvisor/pkg/syserr"
@@ -38,9 +40,9 @@ type provider struct {
netProto tcpip.NetworkProtocolNumber
}
-// GetTransportProtocol figures out transport protocol. Currently only TCP,
+// getTransportProtocol figures out transport protocol. Currently only TCP,
// UDP, and ICMP are supported.
-func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+func getTransportProtocol(ctx context.Context, stype transport.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
switch stype {
case linux.SOCK_STREAM:
if protocol != 0 && protocol != syscall.IPPROTO_TCP {
@@ -57,6 +59,18 @@ func GetTransportProtocol(stype transport.SockType, protocol int) (tcpip.Transpo
case syscall.IPPROTO_ICMPV6:
return header.ICMPv6ProtocolNumber, nil
}
+
+ case linux.SOCK_RAW:
+ // Raw sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ switch protocol {
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, nil
+ }
}
return 0, syserr.ErrInvalidArgument
}
@@ -76,14 +90,20 @@ func (p *provider) Socket(t *kernel.Task, stype transport.SockType, protocol int
}
// Figure out the transport protocol.
- transProto, err := GetTransportProtocol(stype, protocol)
+ transProto, err := getTransportProtocol(t, stype, protocol)
if err != nil {
return nil, err
}
// Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
wq := &waiter.Queue{}
- ep, e := eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ }
if e != nil {
return nil, syserr.TranslateNetstackError(e)
}