summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD4
-rw-r--r--pkg/sentry/socket/control/BUILD5
-rw-r--r--pkg/sentry/socket/control/control.go209
-rw-r--r--pkg/sentry/socket/hostinet/BUILD7
-rw-r--r--pkg/sentry/socket/hostinet/socket.go110
-rw-r--r--pkg/sentry/socket/hostinet/stack.go116
-rw-r--r--pkg/sentry/socket/netfilter/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/BUILD6
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD5
-rw-r--r--pkg/sentry/socket/netlink/provider.go7
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go5
-rw-r--r--pkg/sentry/socket/netlink/socket.go140
-rw-r--r--pkg/sentry/socket/netlink/uevent/BUILD17
-rw-r--r--pkg/sentry/socket/netlink/uevent/protocol.go60
-rw-r--r--pkg/sentry/socket/netstack/BUILD (renamed from pkg/sentry/socket/epsocket/BUILD)10
-rw-r--r--pkg/sentry/socket/netstack/device.go (renamed from pkg/sentry/socket/epsocket/device.go)6
-rw-r--r--pkg/sentry/socket/netstack/netstack.go (renamed from pkg/sentry/socket/epsocket/epsocket.go)412
-rw-r--r--pkg/sentry/socket/netstack/provider.go (renamed from pkg/sentry/socket/epsocket/provider.go)57
-rw-r--r--pkg/sentry/socket/netstack/save_restore.go (renamed from pkg/sentry/socket/epsocket/save_restore.go)2
-rw-r--r--pkg/sentry/socket/netstack/stack.go (renamed from pkg/sentry/socket/epsocket/stack.go)110
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD10
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go10
-rw-r--r--pkg/sentry/socket/rpcinet/syscall_rpc.proto1
-rw-r--r--pkg/sentry/socket/socket.go5
-rw-r--r--pkg/sentry/socket/unix/BUILD6
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go5
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go16
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go91
-rw-r--r--pkg/sentry/socket/unix/unix.go30
31 files changed, 1246 insertions, 228 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 3300f9a6b..26176b10d 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "socket",
srcs = ["socket.go"],
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 81dbd7309..357517ed4 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "control",
srcs = ["control.go"],
@@ -17,6 +17,7 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
"//pkg/syserror",
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 4e95101b7..af1a4e95f 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserror"
@@ -194,15 +195,15 @@ func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([
// the available space, we must align down.
//
// align must be >= 4 and each data int32 is 4 bytes. The length of the
- // header is already aligned, so if we align to the with of the data there
+ // header is already aligned, so if we align to the width of the data there
// are two cases:
// 1. The aligned length is less than the length of the header. The
// unaligned length was also less than the length of the header, so we
// can't write anything.
// 2. The aligned length is greater than or equal to the length of the
- // header. We can write the header plus zero or more datas. We can't write
- // a partial int32, so the length of the message will be
- // min(aligned length, header + datas).
+ // header. We can write the header plus zero or more bytes of data. We can't
+ // write a partial int32, so the length of the message will be
+ // min(aligned length, header + data).
if space < linux.SizeOfControlMessageHeader {
flags |= linux.MSG_CTRUNC
return buf, flags
@@ -239,12 +240,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf
buf = binary.Marshal(buf, usermem.ByteOrder, data)
- // Check if we went over.
+ // If the control message data brought us over capacity, omit it.
if cap(buf) != cap(ob) {
return hdrBuf
}
- // Fix up length.
+ // Update control message length to include data.
putUint64(ob, uint64(len(buf)-len(ob)))
return alignSlice(buf, align)
@@ -320,35 +321,109 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
buf,
linux.SOL_TCP,
linux.TCP_INQ,
- 4,
+ t.Arch().Width(),
inq,
)
}
+// PackTOS packs an IP_TOS socket control message.
+func PackTOS(t *kernel.Task, tos int8, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_TOS,
+ t.Arch().Width(),
+ tos,
+ )
+}
+
+// PackTClass packs an IPV6_TCLASS socket control message.
+func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_TCLASS,
+ t.Arch().Width(),
+ tClass,
+ )
+}
+
+// PackControlMessages packs control messages into the given buffer.
+//
+// We skip control messages specific to Unix domain sockets.
+//
+// Note that some control messages may be truncated if they do not fit under
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+ if cmsgs.IP.HasTimestamp {
+ buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
+ }
+
+ if cmsgs.IP.HasInq {
+ // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
+ buf = PackInq(t, cmsgs.IP.Inq, buf)
+ }
+
+ if cmsgs.IP.HasTOS {
+ buf = PackTOS(t, cmsgs.IP.TOS, buf)
+ }
+
+ if cmsgs.IP.HasTClass {
+ buf = PackTClass(t, cmsgs.IP.TClass, buf)
+ }
+
+ return buf
+}
+
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
// Parse parses a raw socket control message into portable objects.
-func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) {
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
- fds linux.ControlMessageRights
- haveCreds bool
- creds linux.ControlMessageCredentials
+ cmsgs socket.ControlMessages
+ fds linux.ControlMessageRights
)
for i := 0; i < len(buf); {
if i+linux.SizeOfControlMessageHeader > len(buf) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return cmsgs, syserror.EINVAL
}
var h linux.ControlMessageHeader
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
if h.Length < uint64(linux.SizeOfControlMessageHeader) {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
if h.Length > uint64(len(buf)-i) {
- return transport.ControlMessages{}, syserror.EINVAL
- }
- if h.Level != linux.SOL_SOCKET {
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
i += linux.SizeOfControlMessageHeader
@@ -358,59 +433,79 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.
// sizeof(long) in CMSG_ALIGN.
width := t.Arch().Width()
- switch h.Type {
- case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
- numRights := rightsSize / linux.SizeOfControlMessageRight
-
- if len(fds)+numRights > linux.SCM_MAX_FD {
- return transport.ControlMessages{}, syserror.EINVAL
+ switch h.Level {
+ case linux.SOL_SOCKET:
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ scmCreds, err := NewSCMCredentials(t, creds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Credentials = scmCreds
+ i += AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
- fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ case linux.SOL_IP:
+ switch h.Type {
+ case linux.IP_TOS:
+ cmsgs.IP.HasTOS = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
+ i += AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- i += AlignUp(length, width)
-
- case linux.SCM_CREDENTIALS:
- if length < linux.SizeOfControlMessageCredentials {
- return transport.ControlMessages{}, syserror.EINVAL
+ case linux.SOL_IPV6:
+ switch h.Type {
+ case linux.IPV6_TCLASS:
+ cmsgs.IP.HasTClass = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
+ i += AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
}
-
- binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
- haveCreds = true
- i += AlignUp(length, width)
-
default:
- // Unknown message type.
- return transport.ControlMessages{}, syserror.EINVAL
+ return socket.ControlMessages{}, syserror.EINVAL
}
}
- var credentials SCMCredentials
- if haveCreds {
- var err error
- if credentials, err = NewSCMCredentials(t, creds); err != nil {
- return transport.ControlMessages{}, err
- }
- } else {
- credentials = makeCreds(t, socketOrEndpoint)
+ if cmsgs.Unix.Credentials == nil {
+ cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint)
}
- var rights SCMRights
if len(fds) > 0 {
- var err error
- if rights, err = NewSCMRights(t, fds); err != nil {
- return transport.ControlMessages{}, err
+ rights, err := NewSCMRights(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
}
+ cmsgs.Unix.Rights = rights
}
- if credentials == nil && rights == nil {
- return transport.ControlMessages{}, nil
- }
-
- return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil
+ return cmsgs, nil
}
func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index a951f1bb0..4c44c7c0f 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "hostinet",
srcs = [
@@ -29,9 +29,12 @@ go_library(
"//pkg/sentry/kernel/time",
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
"//pkg/sentry/usermem",
"//pkg/syserr",
"//pkg/syserror",
+ "//pkg/tcpip/stack",
"//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 92beb1bcf..c957b0f1d 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -18,6 +18,7 @@ import (
"fmt"
"syscall"
+ "golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/fdnotifier"
@@ -29,6 +30,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
@@ -41,6 +43,10 @@ const (
// sizeofSockaddr is the size in bytes of the largest sockaddr type
// supported by this package.
sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
+
+ // maxControlLen is the maximum size of a control message buffer used in a
+ // recvmsg or sendmsg syscall.
+ maxControlLen = 1024
)
// socketOperations implements fs.FileOperations and socket.Socket for a socket
@@ -281,26 +287,32 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
// Whitelist options and constrain option length.
var optlen int
switch level {
- case syscall.SOL_IPV6:
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_IPV6:
switch name {
- case syscall.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
- case syscall.SOL_SOCKET:
+ case linux.SOL_SOCKET:
switch name {
- case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
optlen = sizeofInt32
- case syscall.SO_LINGER:
+ case linux.SO_LINGER:
optlen = syscall.SizeofLinger
}
- case syscall.SOL_TCP:
+ case linux.SOL_TCP:
switch name {
- case syscall.TCP_NODELAY:
+ case linux.TCP_NODELAY:
optlen = sizeofInt32
- case syscall.TCP_INFO:
+ case linux.TCP_INFO:
optlen = int(linux.SizeOfTCPInfo)
}
}
+
if optlen == 0 {
return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
}
@@ -320,19 +332,24 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
// Whitelist options and constrain option length.
var optlen int
switch level {
- case syscall.SOL_IPV6:
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_IPV6:
switch name {
- case syscall.IPV6_V6ONLY:
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
optlen = sizeofInt32
}
- case syscall.SOL_SOCKET:
+ case linux.SOL_SOCKET:
switch name {
- case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
optlen = sizeofInt32
}
- case syscall.SOL_TCP:
+ case linux.SOL_TCP:
switch name {
- case syscall.TCP_NODELAY:
+ case linux.TCP_NODELAY:
optlen = sizeofInt32
}
}
@@ -354,11 +371,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
}
// RecvMsg implements socket.Socket.RecvMsg.
-func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
// Whitelist flags.
//
// FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
- // messages that netstack/tcpip/transport/unix doesn't understand. Kill the
+ // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
// Socket interface's dependence on netstack.
if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
@@ -370,6 +387,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
senderAddrBuf = make([]byte, sizeofSockaddr)
}
+ var controlBuf []byte
var msgFlags int
recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
@@ -384,11 +402,6 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
// We always do a non-blocking recv*().
sysflags := flags | syscall.MSG_DONTWAIT
- if dsts.NumBlocks() == 1 {
- // Skip allocating []syscall.Iovec.
- return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf)
- }
-
iovs := iovecsFromBlockSeq(dsts)
msg := syscall.Msghdr{
Iov: &iovs[0],
@@ -398,12 +411,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
msg.Name = &senderAddrBuf[0]
msg.Namelen = uint32(len(senderAddrBuf))
}
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
n, err := recvmsg(s.fd, &msg, sysflags)
if err != nil {
return 0, err
}
senderAddrBuf = senderAddrBuf[:msg.Namelen]
msgFlags = int(msg.Flags)
+ controlLen = uint64(msg.Controllen)
return n, nil
})
@@ -429,14 +451,38 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
}
}
-
- // We don't allow control messages.
- msgFlags &^= linux.MSG_CTRUNC
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
if senderRequested {
senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
}
- return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err)
+
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ controlMessages := socket.ControlMessages{}
+ for _, unixCmsg := range unixControlMessages {
+ switch unixCmsg.Header.Level {
+ case syscall.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case syscall.IP_TOS:
+ controlMessages.IP.HasTOS = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+ }
+ case syscall.SOL_IPV6:
+ switch unixCmsg.Header.Type {
+ case syscall.IPV6_TCLASS:
+ controlMessages.IP.HasTClass = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+ }
+ }
+ }
+
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
}
// SendMsg implements socket.Socket.SendMsg.
@@ -446,6 +492,14 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.ErrInvalidArgument
}
+ space := uint64(control.CmsgsSpace(t, controlMessages))
+ if space > maxControlLen {
+ space = maxControlLen
+ }
+ controlBuf := make([]byte, 0, space)
+ // PackControlMessages will append up to space bytes to controlBuf.
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+
sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
// Refuse to do anything if any part of src.Addrs was unusable.
if uint64(src.NumBytes()) != srcs.NumBytes() {
@@ -458,7 +512,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
// We always do a non-blocking send*().
sysflags := flags | syscall.MSG_DONTWAIT
- if srcs.NumBlocks() == 1 {
+ if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
// Skip allocating []syscall.Iovec.
src := srcs.Head()
n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
@@ -477,6 +531,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
msg.Name = &to[0]
msg.Namelen = uint32(len(to))
}
+ if len(controlBuf) != 0 {
+ msg.Control = &controlBuf[0]
+ msg.Controllen = uint64(len(controlBuf))
+ }
return sendmsg(s.fd, &msg, sysflags)
})
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 3a4fdec47..e67b46c9e 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -16,8 +16,11 @@ package hostinet
import (
"fmt"
+ "io"
"io/ioutil"
"os"
+ "reflect"
+ "strconv"
"strings"
"syscall"
@@ -26,7 +29,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/context"
"gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/usermem"
+ "gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
var defaultRecvBufSize = inet.TCPBufferSize{
@@ -51,6 +56,8 @@ type Stack struct {
tcpRecvBufSize inet.TCPBufferSize
tcpSendBufSize inet.TCPBufferSize
tcpSACKEnabled bool
+ netDevFile *os.File
+ netSNMPFile *os.File
}
// NewStack returns an empty Stack containing no configuration.
@@ -98,6 +105,18 @@ func (s *Stack) Configure() error {
log.Warningf("Failed to read if TCP SACK if enabled, setting to true")
}
+ if f, err := os.Open("/proc/net/dev"); err != nil {
+ log.Warningf("Failed to open /proc/net/dev: %v", err)
+ } else {
+ s.netDevFile = f
+ }
+
+ if f, err := os.Open("/proc/net/snmp"); err != nil {
+ log.Warningf("Failed to open /proc/net/snmp: %v", err)
+ } else {
+ s.netSNMPFile = f
+ }
+
return nil
}
@@ -326,9 +345,95 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
return syserror.EACCES
}
+// getLine reads one line from proc file, with specified prefix.
+// The last argument, withHeader, specifies if it contains line header.
+func getLine(f *os.File, prefix string, withHeader bool) string {
+ data := make([]byte, 4096)
+
+ if _, err := f.Seek(0, 0); err != nil {
+ return ""
+ }
+
+ if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF {
+ return ""
+ }
+
+ prefix = prefix + ":"
+ lines := strings.Split(string(data), "\n")
+ for _, l := range lines {
+ l = strings.TrimSpace(l)
+ if strings.HasPrefix(l, prefix) {
+ if withHeader {
+ withHeader = false
+ continue
+ }
+ return l
+ }
+ }
+ return ""
+}
+
+func toSlice(i interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(i))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
- return syserror.EOPNOTSUPP
+ var (
+ snmpTCP bool
+ rawLine string
+ sliceStat []uint64
+ )
+
+ switch stat.(type) {
+ case *inet.StatDev:
+ if s.netDevFile == nil {
+ return fmt.Errorf("/proc/net/dev is not opened for hostinet")
+ }
+ rawLine = getLine(s.netDevFile, arg, false /* with no header */)
+ case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite:
+ if s.netSNMPFile == nil {
+ return fmt.Errorf("/proc/net/snmp is not opened for hostinet")
+ }
+ rawLine = getLine(s.netSNMPFile, arg, true)
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+
+ if rawLine == "" {
+ return fmt.Errorf("Failed to get raw line")
+ }
+
+ parts := strings.SplitN(rawLine, ":", 2)
+ if len(parts) != 2 {
+ return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ }
+
+ sliceStat = toSlice(stat)
+ fields := strings.Fields(strings.TrimSpace(parts[1]))
+ if len(fields) != len(sliceStat) {
+ return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ }
+ if _, ok := stat.(*inet.StatSNMPTCP); ok {
+ snmpTCP = true
+ }
+ for i := 0; i < len(sliceStat); i++ {
+ var err error
+ if snmpTCP && i == 3 {
+ var tmp int64
+ // MaxConn field is signed, RFC 2012.
+ tmp, err = strconv.ParseInt(fields[i], 10, 64)
+ sliceStat[i] = uint64(tmp) // Convert back to int before use.
+ } else {
+ sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
+ }
+ if err != nil {
+ return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ }
+ }
+
+ return nil
}
// RouteTable implements inet.Stack.RouteTable.
@@ -338,3 +443,12 @@ func (s *Stack) RouteTable() []inet.Route {
// Resume implements inet.Stack.Resume.
func (s *Stack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 354a0d6ee..5eb06bbf4 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netfilter",
srcs = [
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 45ebb2a0e..79589e3c8 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "netlink",
srcs = [
@@ -20,7 +20,9 @@ go_library(
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
"//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
"//pkg/sentry/kernel/time",
+ "//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/netlink/port",
"//pkg/sentry/socket/unix",
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
index 9e2e12799..463544c1a 100644
--- a/pkg/sentry/socket/netlink/port/BUILD
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -1,6 +1,7 @@
-package(licenses = ["notice"])
+load("@io_bazel_rules_go//go:def.bzl", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_library")
-load("//tools/go_stateify:defs.bzl", "go_library", "go_test")
+package(licenses = ["notice"])
go_library(
name = "port",
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
index 689cad997..be005df24 100644
--- a/pkg/sentry/socket/netlink/provider.go
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -30,6 +30,13 @@ type Protocol interface {
// Protocol returns the Linux netlink protocol value.
Protocol() int
+ // CanSend returns true if this protocol may ever send messages.
+ //
+ // TODO(gvisor.dev/issue/1119): This is a workaround to allow
+ // advertising support for otherwise unimplemented features on sockets
+ // that will never send messages, thus making those features no-ops.
+ CanSend() bool
+
// ProcessMessage processes a single message from userspace.
//
// If err == nil, any messages added to ms will be sent back to the
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
index 5dc8533ec..1d4912753 100644
--- a/pkg/sentry/socket/netlink/route/BUILD
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "route",
srcs = ["protocol.go"],
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index cc70ac237..6b4a0ecf4 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -61,6 +61,11 @@ func (p *Protocol) Protocol() int {
return linux.NETLINK_ROUTE
}
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return true
+}
+
// dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests.
func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
// NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index d0aab293d..4a1b87a9a 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -27,7 +27,9 @@ import (
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
"gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/safemem"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
"gvisor.dev/gvisor/pkg/sentry/socket/unix"
@@ -52,6 +54,8 @@ const (
maxSendBufferSize = 4 << 20 // 4MB
)
+var errNoFilter = syserr.New("no filter attached", linux.ENOENT)
+
// netlinkSocketDevice is the netlink socket virtual device.
var netlinkSocketDevice = device.NewAnonDevice()
@@ -60,7 +64,7 @@ var netlinkSocketDevice = device.NewAnonDevice()
// This implementation only supports userspace sending and receiving messages
// to/from the kernel.
//
-// Socket implements socket.Socket.
+// Socket implements socket.Socket and transport.Credentialer.
//
// +stateify savable
type Socket struct {
@@ -103,9 +107,19 @@ type Socket struct {
// sendBufferSize is the send buffer "size". We don't actually have a
// fixed buffer but only consume this many bytes.
sendBufferSize uint32
+
+ // passcred indicates if this socket wants SCM credentials.
+ passcred bool
+
+ // filter indicates that this socket has a BPF filter "installed".
+ //
+ // TODO(gvisor.dev/issue/1119): We don't actually support filtering,
+ // this is just bookkeeping for tracking add/remove.
+ filter bool
}
var _ socket.Socket = (*Socket)(nil)
+var _ transport.Credentialer = (*Socket)(nil)
// NewSocket creates a new Socket.
func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
@@ -171,6 +185,22 @@ func (s *Socket) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
+// Passcred implements transport.Credentialer.Passcred.
+func (s *Socket) Passcred() bool {
+ s.mu.Lock()
+ passcred := s.passcred
+ s.mu.Unlock()
+ return passcred
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *Socket) ConnectedPasscred() bool {
+ // This socket is connected to the kernel, which doesn't need creds.
+ //
+ // This is arbitrary, as ConnectedPasscred on this type has no callers.
+ return false
+}
+
// Ioctl implements fs.FileOperations.Ioctl.
func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
// TODO(b/68878065): no ioctls supported.
@@ -308,9 +338,20 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.
// We don't have limit on receiving size.
return int32(math.MaxInt32), nil
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ var passcred int32
+ if s.Passcred() {
+ passcred = 1
+ }
+ return passcred, nil
+
default:
socket.GetSockOptEmitUnimplementedEvent(t, name)
}
+
case linux.SOL_NETLINK:
switch name {
case linux.NETLINK_BROADCAST_ERROR,
@@ -347,6 +388,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
s.sendBufferSize = size
s.mu.Unlock()
return nil
+
case linux.SO_RCVBUF:
if len(opt) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -354,6 +396,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy
// We don't have limit on receiving size. So just accept anything as
// valid for compatibility.
return nil
+
+ case linux.SO_PASSCRED:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ passcred := usermem.ByteOrder.Uint32(opt)
+
+ s.mu.Lock()
+ s.passcred = passcred != 0
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_ATTACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): We don't actually
+ // support filtering. If this socket can't ever send
+ // messages, then there is nothing to filter and we can
+ // advertise support. Otherwise, be conservative and
+ // return an error.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ s.filter = true
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_DETACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): See above.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ filter := s.filter
+ s.filter = false
+ s.mu.Unlock()
+
+ if !filter {
+ return errNoFilter
+ }
+
+ return nil
+
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
@@ -416,6 +504,24 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have
Peek: flags&linux.MSG_PEEK != 0,
}
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned. However, the memory manager for the destination will not read
+ // the endpoint if the destination is zero length.
+ //
+ // In order for the endpoint to be read when the destination size is zero,
+ // we must cause a read of the endpoint by using a separate fake zero
+ // length block sequence and calling the EndpointReader directly.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ // Perform a read to a zero byte block sequence. We can ignore the
+ // original destination since it was zero bytes. The length returned by
+ // ReadToBlocks is ignored and we return the full message length to comply
+ // with MSG_TRUNC.
+ _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0))))
+ return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
var mflags int
if n < int64(r.MsgSize) {
@@ -464,6 +570,26 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _
})
}
+// kernelSCM implements control.SCMCredentials with credentials that represent
+// the kernel itself rather than a Task.
+//
+// +stateify savable
+type kernelSCM struct{}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
+ _, ok := oc.(kernelSCM)
+ return ok
+}
+
+// Credentials implements control.SCMCredentials.Credentials.
+func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ return 0, auth.RootUID, auth.RootGID
+}
+
+// kernelCreds is the concrete version of kernelSCM used in all creds.
+var kernelCreds = &kernelSCM{}
+
// sendResponse sends the response messages in ms back to userspace.
func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
// Linux combines multiple netlink messages into a single datagram.
@@ -472,10 +598,15 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
bufs = append(bufs, m.Finalize())
}
+ // All messages are from the kernel.
+ cms := transport.ControlMessages{
+ Credentials: kernelCreds,
+ }
+
if len(bufs) > 0 {
// RecvMsg never receives the address, so we don't need to send
// one.
- _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{})
+ _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{})
// If the buffer is full, we simply drop messages, just like
// Linux.
if err != nil && err != syserr.ErrWouldBlock {
@@ -499,7 +630,10 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error
PortID: uint32(ms.PortID),
})
- _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{})
+ // Add the dump_done_errno payload.
+ m.Put(int64(0))
+
+ _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
if err != nil && err != syserr.ErrWouldBlock {
return err
}
diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD
new file mode 100644
index 000000000..0777f3baf
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/BUILD
@@ -0,0 +1,17 @@
+load("//tools/go_stateify:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "uevent",
+ srcs = ["protocol.go"],
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/uevent",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go
new file mode 100644
index 000000000..b5d7808d7
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/protocol.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol.
+//
+// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does
+// not support any device events, so these sockets never send any messages.
+package uevent
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_KOBJECT_UEVENT
+}
+
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return false
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+ // Silently ignore all messages.
+ return nil
+}
+
+// init registers the NETLINK_KOBJECT_UEVENT provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol)
+}
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/netstack/BUILD
index e927821e1..e414d8055 100644
--- a/pkg/sentry/socket/epsocket/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -1,17 +1,17 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
- name = "epsocket",
+ name = "netstack",
srcs = [
"device.go",
- "epsocket.go",
+ "netstack.go",
"provider.go",
"save_restore.go",
"stack.go",
],
- importpath = "gvisor.dev/gvisor/pkg/sentry/socket/epsocket",
+ importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack",
visibility = [
"//pkg/sentry:internal",
],
diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/netstack/device.go
index 85484d5b1..fbeb89fb8 100644
--- a/pkg/sentry/socket/epsocket/device.go
+++ b/pkg/sentry/socket/netstack/device.go
@@ -12,9 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import "gvisor.dev/gvisor/pkg/sentry/device"
-// epsocketDevice is the endpoint socket virtual device.
-var epsocketDevice = device.NewAnonDevice()
+// netstackDevice is the endpoint socket virtual device.
+var netstackDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/netstack/netstack.go
index 0e37ce61b..140851c17 100644
--- a/pkg/sentry/socket/epsocket/epsocket.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Package epsocket provides an implementation of the socket.Socket interface
+// Package netstack provides an implementation of the socket.Socket interface
// that is backed by a tcpip.Endpoint.
//
// It does not depend on any particular endpoint implementation, and thus can
@@ -22,10 +22,11 @@
// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
// this operation.
-package epsocket
+package netstack
import (
"bytes"
+ "io"
"math"
"reflect"
"sync"
@@ -52,6 +53,7 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
@@ -136,15 +138,21 @@ var Metrics = tcpip.Stats{
},
},
IP: tcpip.IPStats{
- PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
- InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
- PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
- PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
- OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
+ InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
+ PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
+ PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
+ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
+ MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
},
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
+ CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."),
+ EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."),
+ EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."),
ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
@@ -154,6 +162,7 @@ var Metrics = tcpip.Stats{
ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."),
InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."),
SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."),
+ SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."),
ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."),
ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."),
Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."),
@@ -169,13 +178,18 @@ var Metrics = tcpip.Stats{
UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."),
ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."),
MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
- PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent via sendUDP."),
+ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
+ PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
},
}
+// DefaultTTL is linux's default TTL. All network protocols in all stacks used
+// with this package must have this value set as their default TTL.
+const DefaultTTL = 64
+
const sizeOfInt32 int = 4
-var errStackType = syserr.New("expected but did not receive an epsocket.Stack", linux.EINVAL)
+var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
// ntohs converts a 16-bit number from network byte order to host byte order. It
// assumes that the host is little endian.
@@ -208,6 +222,10 @@ type commonEndpoint interface {
// transport.Endpoint.SetSockOpt.
SetSockOpt(interface{}) *tcpip.Error
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
GetSockOpt(interface{}) *tcpip.Error
@@ -227,7 +245,6 @@ type SocketOperations struct {
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileNoFsync `state:"nosave"`
fsutil.FileNoMMap `state:"nosave"`
- fsutil.FileNoSplice `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
socket.SendReceiveTimeout
*waiter.Queue
@@ -258,20 +275,20 @@ type SocketOperations struct {
// valid when timestampValid is true. It is protected by readMu.
timestampNS int64
- // sockOptInq corresponds to TCP_INQ. It is implemented on the epsocket
- // level, because it takes into account data from readView.
+ // sockOptInq corresponds to TCP_INQ. It is implemented at this level
+ // because it takes into account data from readView.
sockOptInq bool
}
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOpt(tcpip.DelayOption(1)); err != nil {
+ if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}
- dirent := socket.NewDirent(t, epsocketDevice)
+ dirent := socket.NewDirent(t, netstackDevice)
defer dirent.DecRef()
return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{
Queue: queue,
@@ -284,6 +301,7 @@ func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue
var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
// netstack representation taking any addresses into account.
@@ -295,12 +313,12 @@ func bytesToIPAddress(addr []byte) tcpip.Address {
}
// AddressAndFamily reads an sockaddr struct from the given address and
-// converts it to the FullAddress format. It supports AF_UNIX, AF_INET and
-// AF_INET6 addresses.
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
//
// strict indicates whether addresses with the AF_UNSPEC family are accepted of not.
//
-// AddressAndFamily returns an address, its family.
+// AddressAndFamily returns an address and its family.
func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) {
// Make sure we have at least 2 bytes for the address family.
if len(addr) < 2 {
@@ -308,7 +326,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
}
family := usermem.ByteOrder.Uint16(addr)
- if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) {
+ if family != uint16(sfamily) && (strict || family != linux.AF_UNSPEC) {
return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported
}
@@ -359,6 +377,22 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress,
}
return out, family, nil
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(b/129292371): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
case linux.AF_UNSPEC:
return tcpip.FullAddress{}, family, nil
@@ -412,17 +446,60 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
return int64(n), nil
}
-// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand
-// based on the requested size.
+// WriteTo implements fs.FileOperations.WriteTo.
+func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+ s.readMu.Lock()
+
+ // Copy as much data as possible.
+ done := int64(0)
+ for count > 0 {
+ // This may return a blocking error.
+ if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
+ return done, err.ToError()
+ }
+
+ // Write to the underlying file.
+ n, err := dst.Write(s.readView)
+ done += int64(n)
+ count -= int64(n)
+ if dup {
+ // That's all we support for dup. This is generally
+ // supported by any Linux system calls, but the
+ // expectation is that now a caller will call read to
+ // actually remove these bytes from the socket.
+ break
+ }
+
+ // Drop that part of the view.
+ s.readView.TrimFront(n)
+ if err != nil {
+ s.readMu.Unlock()
+ return done, err
+ }
+ }
+
+ s.readMu.Unlock()
+ return done, nil
+}
+
+// ioSequencePayload implements tcpip.Payload.
+//
+// t copies user memory bytes on demand based on the requested size.
type ioSequencePayload struct {
ctx context.Context
src usermem.IOSequence
}
-// Get implements tcpip.Payload.
-func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
- if size > i.Size() {
- size = i.Size()
+// FullPayload implements tcpip.Payloader.FullPayload
+func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
+ return i.Payload(int(i.src.NumBytes()))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if max := int(i.src.NumBytes()); size > max {
+ size = max
}
v := buffer.NewView(size)
if _, err := i.src.CopyIn(i.ctx, v); err != nil {
@@ -431,11 +508,6 @@ func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
return v, nil
}
-// Size implements tcpip.Payload.
-func (i *ioSequencePayload) Size() int {
- return int(i.src.NumBytes())
-}
-
// DropFirst drops the first n bytes from underlying src.
func (i *ioSequencePayload) DropFirst(n int) {
i.src = i.src.DropFirst(int(n))
@@ -469,6 +541,78 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
return int64(n), nil
}
+// readerPayload implements tcpip.Payloader.
+//
+// It allocates a view and reads from a reader on-demand, based on available
+// capacity in the endpoint.
+type readerPayload struct {
+ ctx context.Context
+ r io.Reader
+ count int64
+ err error
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
+ return r.Payload(int(r.count))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if size > int(r.count) {
+ size = int(r.count)
+ }
+ v := buffer.NewView(size)
+ n, err := r.r.Read(v)
+ if n > 0 {
+ // We ignore the error here. It may re-occur on subsequent
+ // reads, but for now we can enqueue some amount of data.
+ r.count -= int64(n)
+ return v[:n], nil
+ }
+ if err == syserror.ErrWouldBlock {
+ return nil, tcpip.ErrWouldBlock
+ } else if err != nil {
+ r.err = err // Save for propation.
+ return nil, tcpip.ErrBadAddress
+ }
+
+ // There is no data and no error. Return an error, which will propagate
+ // r.err, which will be nil. This is the desired result: (0, nil).
+ return nil, tcpip.ErrBadAddress
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ f := &readerPayload{ctx: ctx, r: r, count: count}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ t := ctx.(*kernel.Task)
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err).ToError()
+ }
+
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
+ Atomic: true, // See above.
+ })
+ }
+ if err == tcpip.ErrWouldBlock {
+ return n, syserror.ErrWouldBlock
+ } else if err != nil {
+ return int64(n), f.err // Propagate error.
+ }
+
+ return int64(n), nil
+}
+
// Readiness returns a mask of ready events for socket s.
func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
r := s.Endpoint.Readiness(mask)
@@ -646,7 +790,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
// tcpip.Endpoint.
func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
- // implemented specifically for epsocket.SocketOperations rather than
+ // implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
// options where the implementation is not shared, as unix sockets need
// their own support for SO_TIMESTAMP.
@@ -719,7 +863,7 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
return getSockOptIPv6(t, ep, name, outLen)
case linux.SOL_IP:
- return getSockOptIP(t, ep, name, outLen)
+ return getSockOptIP(t, ep, name, outLen, family)
case linux.SOL_UDP,
linux.SOL_ICMPV6,
@@ -777,8 +921,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.SendBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -793,8 +937,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var size tcpip.ReceiveBufferSizeOption
- if err := ep.GetSockOpt(&size); err != nil {
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -828,6 +972,19 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return int32(v), nil
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if len(v) == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return append([]byte(v), 0), nil
+
case linux.SO_BROADCAST:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -900,8 +1057,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.DelayOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -970,6 +1127,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return int32(time.Duration(v) / time.Second), nil
+ case linux.TCP_USER_TIMEOUT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPUserTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Millisecond), nil
+
case linux.TCP_INFO:
var v tcpip.TCPInfoOption
if err := ep.GetSockOpt(&v); err != nil {
@@ -1018,6 +1187,18 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
copy(b, v)
return b, nil
+ case linux.TCP_LINGER2:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPLingerTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
default:
emitUnimplementedEventTCP(t, name)
}
@@ -1042,6 +1223,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_TCLASS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return make([]byte, 0), nil
+ }
+ var v tcpip.IPv6TrafficClassOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ uintv := uint32(v)
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ // Handle cases where outLen is lesser than sizeOfInt32.
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+ return ib, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1049,8 +1249,25 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
switch name {
+ case linux.IP_TTL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TTLOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // Fill in the default value, if needed.
+ if v == 0 {
+ v = DefaultTTL
+ }
+
+ return int32(v), nil
+
case linux.IP_MULTICAST_TTL:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
@@ -1092,6 +1309,20 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac
}
return int32(0), nil
+ case linux.IP_TOS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return []byte(nil), nil
+ }
+ var v tcpip.IPv4TOSOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if outLen < sizeOfInt32 {
+ return uint8(v), nil
+ }
+ return int32(v), nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1102,7 +1333,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfac
// tcpip.Endpoint.
func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
// TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
- // implemented specifically for epsocket.SocketOperations rather than
+ // implemented specifically for netstack.SocketOperations rather than
// commonEndpoint. commonEndpoint should be extended to support socket
// options where the implementation is not shared, as unix sockets need
// their own support for SO_TIMESTAMP.
@@ -1165,7 +1396,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
case linux.SO_RCVBUF:
if len(optVal) < sizeOfInt32 {
@@ -1173,7 +1404,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
case linux.SO_REUSEADDR:
if len(optVal) < sizeOfInt32 {
@@ -1191,6 +1422,13 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
v := usermem.ByteOrder.Uint32(optVal)
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n])))
+
case linux.SO_BROADCAST:
if len(optVal) < sizeOfInt32 {
return syserr.ErrInvalidArgument
@@ -1285,11 +1523,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- var o tcpip.DelayOption
+ var o int
if v == 0 {
o = 1
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(o))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -1337,6 +1575,17 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ case linux.TCP_USER_TIMEOUT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+
case linux.TCP_CONGESTION:
v := tcpip.CongestionControlOption(optVal)
if err := ep.SetSockOpt(v); err != nil {
@@ -1344,6 +1593,14 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
return nil
+ case linux.TCP_LINGER2:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+
case linux.TCP_REPAIR_OPTIONS:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1383,6 +1640,19 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
t.Kernel().EmitUnimplementedEvent(t)
+ case linux.IPV6_TCLASS:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < -1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ if v == -1 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1514,6 +1784,30 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
t.Kernel().EmitUnimplementedEvent(t)
return syserr.ErrInvalidArgument
+ case linux.IP_TTL:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ // -1 means default TTL.
+ if v == -1 {
+ v = 0
+ } else if v < 1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v)))
+
+ case linux.IP_TOS:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1537,9 +1831,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_RECVTOS,
linux.IP_RECVTTL,
linux.IP_RETOPTS,
- linux.IP_TOS,
linux.IP_TRANSPARENT,
- linux.IP_TTL,
linux.IP_UNBLOCK_SOURCE,
linux.IP_UNICAST_IF,
linux.IP_XFRM_POLICY,
@@ -1724,12 +2016,14 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
return &out, uint32(2 + l)
}
return &out, uint32(3 + l)
+
case linux.AF_INET:
var out linux.SockAddrInet
copy(out.Addr[:], addr.Addr)
out.Family = linux.AF_INET
out.Port = htons(addr.Port)
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInetSize)
+
case linux.AF_INET6:
var out linux.SockAddrInet6
if len(addr.Addr) == 4 {
@@ -1745,7 +2039,17 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32)
if isLinkLocal(addr.Addr) {
out.Scope_id = uint32(addr.NIC)
}
- return &out, uint32(binary.Size(out))
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(b/129292371): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
default:
return nil, 0
}
@@ -2060,7 +2364,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
n, _, err = s.Endpoint.Write(v, opts)
}
dontWait := flags&linux.MSG_DONTWAIT != 0
- if err == nil && (n >= int64(v.Size()) || dontWait) {
+ if err == nil && (n >= v.src.NumBytes() || dontWait) {
// Complete write.
return int(n), nil
}
@@ -2085,7 +2389,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
return 0, syserr.TranslateNetstackError(err)
}
- if err == nil && v.Size() == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
return int(total), nil
}
@@ -2101,7 +2405,7 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
// Ioctl implements fs.FileOperations.Ioctl.
func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- // SIOCGSTAMP is implemented by epsocket rather than all commonEndpoint
+ // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
// sockets.
// TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
switch args[1].Int() {
@@ -2207,9 +2511,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc
return 0, err
case linux.TIOCOUTQ:
- var v tcpip.SendQueueSizeOption
- if err := ep.GetSockOpt(&v); err != nil {
- return 0, syserr.TranslateNetstackError(err).ToError()
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
}
if v > math.MaxInt32 {
@@ -2404,7 +2708,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
// Flag values and meanings are described in greater detail in netdevice(7) in
// the SIOCGIFFLAGS section.
func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) {
- // epsocket should only ever be passed an epsocket.Stack.
+ // We should only ever be passed a netstack.Stack.
epstack, ok := stack.(*Stack)
if !ok {
return 0, errStackType
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/netstack/provider.go
index 421f93dc4..2d2c1ba2a 100644
--- a/pkg/sentry/socket/epsocket/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import (
"syscall"
@@ -62,10 +62,14 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
+ // TODO(b/142504697): "In order to create a raw socket, a
+ // process must have the CAP_NET_RAW capability in the user
+ // namespace that governs its network namespace." - raw(7)
+
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
- return 0, true, syserr.ErrPermissionDenied
+ return 0, true, syserr.ErrNotPermitted
}
switch protocol {
@@ -85,7 +89,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
return 0, true, syserr.ErrProtocolNotSupported
}
-// Socket creates a new socket object for the AF_INET or AF_INET6 family.
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
// Fail right away if we don't have a stack.
stack := t.NetworkContext()
@@ -99,6 +104,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return nil, nil
}
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocket(t, eps, stype, protocol)
+ }
+
// Figure out the transport protocol.
transProto, associated, err := getTransportProtocol(t, stype, protocol)
if err != nil {
@@ -121,12 +132,47 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
return New(t, p.family, stype, int(transProto), wq, ep)
}
+func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // TODO(b/142504697): "In order to create a packet socket, a process
+ // must have the CAP_NET_RAW capability in the user namespace that
+ // governs its network namespace." - packet(7)
+
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return New(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
// Pair just returns nil sockets (not supported).
func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
return nil, nil, nil
}
-// init registers socket providers for AF_INET and AF_INET6.
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
func init() {
// Providers backed by netstack.
p := []provider{
@@ -138,6 +184,9 @@ func init() {
family: linux.AF_INET6,
netProto: ipv6.ProtocolNumber,
},
+ {
+ family: linux.AF_PACKET,
+ },
}
for i := range p {
diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go
index f7b8c10cc..c7aaf722a 100644
--- a/pkg/sentry/socket/epsocket/save_restore.go
+++ b/pkg/sentry/socket/netstack/save_restore.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import (
"gvisor.dev/gvisor/pkg/tcpip/stack"
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/netstack/stack.go
index 7cf7ff735..a0db2d4fd 100644
--- a/pkg/sentry/socket/epsocket/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-package epsocket
+package netstack
import (
"gvisor.dev/gvisor/pkg/abi/linux"
@@ -144,7 +144,98 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
- return syserr.ErrEndpointOperation.ToError()
+ switch stats := stat.(type) {
+ case *inet.StatSNMPIP:
+ ip := Metrics.IP
+ *stats = inet.StatSNMPIP{
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL.
+ ip.PacketsReceived.Value(), // InReceives.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors.
+ ip.InvalidAddressesReceived.Value(), // InAddrErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards.
+ ip.PacketsDelivered.Value(), // InDelivers.
+ ip.PacketsSent.Value(), // OutRequests.
+ ip.OutgoingPacketErrors.Value(), // OutDiscards.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails.
+ 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates.
+ }
+ case *inet.StatSNMPICMP:
+ in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ *stats = inet.StatSNMPICMP{
+ 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs.
+ Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors.
+ in.DstUnreachable.Value(), // InDestUnreachs.
+ in.TimeExceeded.Value(), // InTimeExcds.
+ in.ParamProblem.Value(), // InParmProbs.
+ in.SrcQuench.Value(), // InSrcQuenchs.
+ in.Redirect.Value(), // InRedirects.
+ in.Echo.Value(), // InEchos.
+ in.EchoReply.Value(), // InEchoReps.
+ in.Timestamp.Value(), // InTimestamps.
+ in.TimestampReply.Value(), // InTimestampReps.
+ in.InfoRequest.Value(), // InAddrMasks.
+ in.InfoReply.Value(), // InAddrMaskReps.
+ 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs.
+ Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
+ }
+ case *inet.StatSNMPTCP:
+ tcp := Metrics.TCP
+ // RFC 2012 (updates 1213): SNMPv2-MIB-TCP.
+ *stats = inet.StatSNMPTCP{
+ 1, // RtoAlgorithm.
+ 200, // RtoMin.
+ 120000, // RtoMax.
+ (1<<64 - 1), // MaxConn.
+ tcp.ActiveConnectionOpenings.Value(), // ActiveOpens.
+ tcp.PassiveConnectionOpenings.Value(), // PassiveOpens.
+ tcp.FailedConnectionAttempts.Value(), // AttemptFails.
+ tcp.EstablishedResets.Value(), // EstabResets.
+ tcp.CurrentEstablished.Value(), // CurrEstab.
+ tcp.ValidSegmentsReceived.Value(), // InSegs.
+ tcp.SegmentsSent.Value(), // OutSegs.
+ tcp.Retransmits.Value(), // RetransSegs.
+ tcp.InvalidSegmentsReceived.Value(), // InErrs.
+ tcp.ResetsSent.Value(), // OutRsts.
+ tcp.ChecksumErrors.Value(), // InCsumErrors.
+ }
+ case *inet.StatSNMPUDP:
+ udp := Metrics.UDP
+ *stats = inet.StatSNMPUDP{
+ udp.PacketsReceived.Value(), // InDatagrams.
+ udp.UnknownPortErrors.Value(), // NoPorts.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors.
+ udp.PacketsSent.Value(), // OutDatagrams.
+ udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors.
+ 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti.
+ }
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+ return nil
}
// RouteTable implements inet.Stack.RouteTable.
@@ -200,3 +291,18 @@ func (s *Stack) FillDefaultIPTables() {
func (s *Stack) Resume() {
s.Stack.Resume()
}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return s.Stack.RegisteredEndpoints()
+}
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint {
+ return s.Stack.CleanupEndpoints()
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) {
+ s.Stack.RestoreCleanupEndpoints(es)
+}
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
index 5061dcbde..4668b87d1 100644
--- a/pkg/sentry/socket/rpcinet/BUILD
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -1,5 +1,6 @@
load("//tools/go_stateify:defs.bzl", "go_library")
load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@rules_cc//cc:defs.bzl", "cc_proto_library")
package(licenses = ["notice"])
@@ -36,6 +37,7 @@ go_library(
"//pkg/syserror",
"//pkg/tcpip",
"//pkg/tcpip/buffer",
+ "//pkg/tcpip/stack",
"//pkg/unet",
"//pkg/waiter",
],
@@ -49,6 +51,14 @@ proto_library(
],
)
+cc_proto_library(
+ name = "syscall_rpc_cc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+ deps = [":syscall_rpc_proto"],
+)
+
go_proto_library(
name = "syscall_rpc_go_proto",
importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
index 5dcb6b455..f7878a760 100644
--- a/pkg/sentry/socket/rpcinet/stack.go
+++ b/pkg/sentry/socket/rpcinet/stack.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn"
"gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier"
"gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/unet"
)
@@ -165,3 +166,12 @@ func (s *Stack) RouteTable() []inet.Route {
// Resume implements inet.Stack.Resume.
func (s *Stack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
index 9586f5923..b677e9eb3 100644
--- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto
+++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
@@ -3,7 +3,6 @@ syntax = "proto3";
// package syscall_rpc is a set of networking related system calls that can be
// forwarded to a socket gofer.
//
-// TODO(b/77963526): Document individual RPCs.
package syscall_rpc;
message SendmsgRequest {
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 8c250c325..2389a9cdb 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -43,6 +43,11 @@ type ControlMessages struct {
IP tcpip.ControlMessages
}
+// Release releases Unix domain socket credentials and rights.
+func (c *ControlMessages) Release() {
+ c.Unix.Release()
+}
+
// Socket is the interface containing socket syscalls used by the syscall layer
// to redirect them to the appropriate implementation.
type Socket interface {
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index da9977fde..5b6a154f6 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -1,7 +1,7 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
+package(licenses = ["notice"])
+
go_library(
name = "unix",
srcs = [
@@ -24,7 +24,7 @@ go_library(
"//pkg/sentry/safemem",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
- "//pkg/sentry/socket/epsocket",
+ "//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
"//pkg/sentry/usermem",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 0b0240336..788ad70d2 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -1,8 +1,8 @@
-package(licenses = ["notice"])
-
load("//tools/go_stateify:defs.bzl", "go_library")
load("//tools/go_generics:defs.bzl", "go_template_instance")
+package(licenses = ["notice"])
+
go_template_instance(
name = "transport_message_list",
out = "transport_message_list.go",
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index 4bd15808a..dea11e253 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -220,6 +220,11 @@ func (e *connectionedEndpoint) Close() {
case e.Connected():
e.connected.CloseSend()
e.receiver.CloseRecv()
+ // Still have unread data? If yes, we set this into the write
+ // end so that the peer can get ECONNRESET) when it does read.
+ if e.receiver.RecvQueuedSize() > 0 {
+ e.connected.CloseUnread()
+ }
c = e.connected
r = e.receiver
e.connected = nil
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index 0415fae9a..e27b1c714 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -33,6 +33,7 @@ type queue struct {
mu sync.Mutex `state:"nosave"`
closed bool
+ unread bool
used int64
limit int64
dataList messageList
@@ -161,6 +162,9 @@ func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
err := syserr.ErrWouldBlock
if q.closed {
err = syserr.ErrClosedForReceive
+ if q.unread {
+ err = syserr.ErrConnectionReset
+ }
}
q.mu.Unlock()
@@ -188,7 +192,9 @@ func (q *queue) Peek() (*message, *syserr.Error) {
if q.dataList.Front() == nil {
err := syserr.ErrWouldBlock
if q.closed {
- err = syserr.ErrClosedForReceive
+ if err = syserr.ErrClosedForReceive; q.unread {
+ err = syserr.ErrConnectionReset
+ }
}
return nil, err
}
@@ -208,3 +214,11 @@ func (q *queue) QueuedSize() int64 {
func (q *queue) MaxQueueSize() int64 {
return q.limit
}
+
+// CloseUnread sets flag to indicate that the peer is closed (not shutdown)
+// with unread data. So if read on this queue shall return ECONNRESET error.
+func (q *queue) CloseUnread() {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.unread = true
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2b0ad6395..529a7a7a9 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -175,6 +175,10 @@ type Endpoint interface {
// types.
SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error
+
// GetSockOpt gets a socket option. opt should be a pointer to one of the
// tcpip.*Option types.
GetSockOpt(opt interface{}) *tcpip.Error
@@ -604,6 +608,10 @@ type ConnectedEndpoint interface {
// Release releases any resources owned by the ConnectedEndpoint. It should
// be called before droping all references to a ConnectedEndpoint.
Release()
+
+ // CloseUnread sets the fact that this end is closed with unread data to
+ // the peer socket.
+ CloseUnread()
}
// +stateify savable
@@ -707,6 +715,11 @@ func (e *connectedEndpoint) Release() {
e.writeQueue.DecRef()
}
+// CloseUnread implements ConnectedEndpoint.CloseUnread.
+func (e *connectedEndpoint) CloseUnread() {
+ e.writeQueue.CloseUnread()
+}
+
// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
// unix domain socket Endpoint implementations.
//
@@ -838,6 +851,10 @@ func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
return nil
}
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error {
+ return nil
+}
+
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
switch opt {
case tcpip.ReceiveQueueSizeOption:
@@ -853,65 +870,63 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) {
return -1, tcpip.ErrQueueSizeNotSupported
}
return v, nil
- default:
- return -1, tcpip.ErrUnknownProtocolOption
- }
-}
-// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- case *tcpip.SendQueueSizeOption:
+ case tcpip.SendQueueSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendQueueSizeOption(e.connected.SendQueuedSize())
+ v := e.connected.SendQueuedSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
- }
- *o = qs
- return nil
-
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- return nil
+ return int(v), nil
- case *tcpip.SendBufferSizeOption:
+ case tcpip.SendBufferSizeOption:
e.Lock()
if !e.Connected() {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.SendBufferSizeOption(e.connected.SendMaxQueueSize())
+ v := e.connected.SendMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
}
- *o = qs
- return nil
+ return int(v), nil
- case *tcpip.ReceiveBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
e.Lock()
if e.receiver == nil {
e.Unlock()
- return tcpip.ErrNotConnected
+ return -1, tcpip.ErrNotConnected
}
- qs := tcpip.ReceiveBufferSizeOption(e.receiver.RecvMaxQueueSize())
+ v := e.receiver.RecvMaxQueueSize()
e.Unlock()
- if qs < 0 {
- return tcpip.ErrQueueSizeNotSupported
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch o := opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ case *tcpip.PasscredOption:
+ if e.Passcred() {
+ *o = tcpip.PasscredOption(1)
+ } else {
+ *o = tcpip.PasscredOption(0)
}
- *o = qs
return nil
case *tcpip.KeepaliveEnabledOption:
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 0d0cb68df..885758054 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -31,7 +31,7 @@ import (
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/control"
- "gvisor.dev/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/sentry/usermem"
"gvisor.dev/gvisor/pkg/syserr"
@@ -40,8 +40,8 @@ import (
"gvisor.dev/gvisor/pkg/waiter"
)
-// SocketOperations is a Unix socket. It is similar to an epsocket, except it
-// is backed by a transport.Endpoint instead of a tcpip.Endpoint.
+// SocketOperations is a Unix socket. It is similar to a netstack socket,
+// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint.
//
// +stateify savable
type SocketOperations struct {
@@ -116,8 +116,11 @@ func (s *SocketOperations) Endpoint() transport.Endpoint {
// extractPath extracts and validates the address.
func extractPath(sockaddr []byte) (string, *syserr.Error) {
- addr, _, err := epsocket.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
+ addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */)
if err != nil {
+ if err == syserr.ErrAddressFamilyNotSupported {
+ err = syserr.ErrInvalidArgument
+ }
return "", err
}
@@ -143,7 +146,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
@@ -155,19 +158,19 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32,
return nil, 0, syserr.TranslateNetstackError(err)
}
- a, l := epsocket.ConvertAddress(linux.AF_UNIX, addr)
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
return a, l, nil
}
// Ioctl implements fs.FileOperations.Ioctl.
func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
- return epsocket.Ioctl(ctx, s.ep, io, args)
+ return netstack.Ioctl(ctx, s.ep, io, args)
}
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
- return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
}
// Listen implements the linux syscall listen(2) for sockets backed by
@@ -474,13 +477,13 @@ func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
- return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal)
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
}
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
- f, err := epsocket.ConvertShutdown(how)
+ f, err := netstack.ConvertShutdown(how)
if err != nil {
return err
}
@@ -546,7 +549,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil && len([]byte(r.From.Addr)) != 0 {
- from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -581,7 +584,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
var from linux.SockAddr
var fromLen uint32
if r.From != nil {
- from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
}
if r.ControlTrunc {
@@ -595,7 +598,8 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
total += n
}
- if err != nil || !waitAll || isPacket || n >= dst.NumBytes() {
+ streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed {
if total > 0 {
err = nil
}