summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD24
-rw-r--r--pkg/sentry/socket/control/BUILD29
-rw-r--r--pkg/sentry/socket/control/control.go591
-rw-r--r--pkg/sentry/socket/control/control_vfs2.go131
-rw-r--r--pkg/sentry/socket/hostinet/BUILD45
-rw-r--r--pkg/sentry/socket/hostinet/device.go19
-rw-r--r--pkg/sentry/socket/hostinet/hostinet.go17
-rw-r--r--pkg/sentry/socket/hostinet/save_restore.go20
-rw-r--r--pkg/sentry/socket/hostinet/socket.go713
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go139
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go202
-rw-r--r--pkg/sentry/socket/hostinet/sockopt_impl.go27
-rw-r--r--pkg/sentry/socket/hostinet/stack.go459
-rw-r--r--pkg/sentry/socket/netfilter/BUILD29
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go95
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go761
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go149
-rw-r--r--pkg/sentry/socket/netfilter/targets.go35
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go130
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go129
-rw-r--r--pkg/sentry/socket/netlink/BUILD52
-rw-r--r--pkg/sentry/socket/netlink/message.go281
-rw-r--r--pkg/sentry/socket/netlink/message_test.go312
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD16
-rw-r--r--pkg/sentry/socket/netlink/port/port.go117
-rw-r--r--pkg/sentry/socket/netlink/port/port_test.go82
-rw-r--r--pkg/sentry/socket/netlink/provider.go116
-rw-r--r--pkg/sentry/socket/netlink/provider_vfs2.go69
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD20
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go498
-rw-r--r--pkg/sentry/socket/netlink/socket.go780
-rw-r--r--pkg/sentry/socket/netlink/socket_vfs2.go152
-rw-r--r--pkg/sentry/socket/netlink/uevent/BUILD16
-rw-r--r--pkg/sentry/socket/netlink/uevent/protocol.go60
-rw-r--r--pkg/sentry/socket/netstack/BUILD56
-rw-r--r--pkg/sentry/socket/netstack/device.go20
-rw-r--r--pkg/sentry/socket/netstack/netstack.go3143
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go330
-rw-r--r--pkg/sentry/socket/netstack/provider.go199
-rw-r--r--pkg/sentry/socket/netstack/provider_vfs2.go141
-rw-r--r--pkg/sentry/socket/netstack/save_restore.go27
-rw-r--r--pkg/sentry/socket/netstack/stack.go386
-rw-r--r--pkg/sentry/socket/socket.go461
-rw-r--r--pkg/sentry/socket/unix/BUILD39
-rw-r--r--pkg/sentry/socket/unix/device.go20
-rw-r--r--pkg/sentry/socket/unix/io.go111
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD41
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go486
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned_state.go53
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go218
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go247
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go1006
-rw-r--r--pkg/sentry/socket/unix/unix.go772
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go371
54 files changed, 14442 insertions, 0 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
new file mode 100644
index 000000000..c40c6d673
--- /dev/null
+++ b/pkg/sentry/socket/BUILD
@@ -0,0 +1,24 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "socket",
+ srcs = ["socket.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
new file mode 100644
index 000000000..ca16d0381
--- /dev/null
+++ b/pkg/sentry/socket/control/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "control",
+ srcs = [
+ "control.go",
+ "control_vfs2.go",
+ ],
+ imports = [
+ "gvisor.dev/gvisor/pkg/sentry/fs",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
new file mode 100644
index 000000000..8b439a078
--- /dev/null
+++ b/pkg/sentry/socket/control/control.go
@@ -0,0 +1,591 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control provides internal representations of socket control
+// messages.
+package control
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const maxInt = int(^uint(0) >> 1)
+
+// SCMCredentials represents a SCM_CREDENTIALS socket control message.
+type SCMCredentials interface {
+ transport.CredentialsControlMessage
+
+ // Credentials returns properly namespaced values for the sender's pid, uid
+ // and gid.
+ Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
+}
+
+// LINT.IfChange
+
+// SCMRights represents a SCM_RIGHTS socket control message.
+type SCMRights interface {
+ transport.RightsControlMessage
+
+ // Files returns up to max RightsFiles.
+ //
+ // Returned files are consumed and ownership is transferred to the caller.
+ // Subsequent calls to Files will return the next files.
+ Files(ctx context.Context, max int) (rf RightsFiles, truncated bool)
+}
+
+// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
+// maintained for each fs.File and is release either when an FD is created or
+// when the Release method is called.
+//
+// +stateify savable
+type RightsFiles []*fs.File
+
+// NewSCMRights creates a new SCM_RIGHTS socket control message representation
+// using local sentry FDs.
+func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
+ files := make(RightsFiles, 0, len(fds))
+ for _, fd := range fds {
+ file := t.GetFile(fd)
+ if file == nil {
+ files.Release()
+ return nil, syserror.EBADF
+ }
+ files = append(files, file)
+ }
+ return &files, nil
+}
+
+// Files implements SCMRights.Files.
+func (fs *RightsFiles) Files(ctx context.Context, max int) (RightsFiles, bool) {
+ n := max
+ var trunc bool
+ if l := len(*fs); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+ rf := (*fs)[:n]
+ *fs = (*fs)[n:]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (fs *RightsFiles) Clone() transport.RightsControlMessage {
+ nfs := append(RightsFiles(nil), *fs...)
+ for _, nf := range nfs {
+ nf.IncRef()
+ }
+ return &nfs
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (fs *RightsFiles) Release() {
+ for _, f := range *fs {
+ f.DecRef()
+ }
+ *fs = nil
+}
+
+// rightsFDs gets up to the specified maximum number of FDs.
+func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32, bool) {
+ files, trunc := rights.Files(t, max)
+ fds := make([]int32, 0, len(files))
+ for i := 0; i < max && len(files) > 0; i++ {
+ fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{
+ CloseOnExec: cloexec,
+ })
+ files[0].DecRef()
+ files = files[1:]
+ if err != nil {
+ t.Warningf("Error inserting FD: %v", err)
+ // This is what Linux does.
+ break
+ }
+
+ fds = append(fds, int32(fd))
+ }
+ return fds, trunc
+}
+
+// PackRights packs as many FDs as will fit into the unused capacity of buf.
+func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flags int) ([]byte, int) {
+ maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
+ // Linux does not return any FDs if none fit.
+ if maxFDs <= 0 {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+ fds, trunc := rightsFDs(t, rights, cloexec, maxFDs)
+ if trunc {
+ flags |= linux.MSG_CTRUNC
+ }
+ align := t.Arch().Width()
+ return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
+}
+
+// LINT.ThenChange(./control_vfs2.go)
+
+// scmCredentials represents an SCM_CREDENTIALS socket control message.
+//
+// +stateify savable
+type scmCredentials struct {
+ t *kernel.Task
+ kuid auth.KUID
+ kgid auth.KGID
+}
+
+// NewSCMCredentials creates a new SCM_CREDENTIALS socket control message
+// representation.
+func NewSCMCredentials(t *kernel.Task, cred linux.ControlMessageCredentials) (SCMCredentials, error) {
+ tcred := t.Credentials()
+ kuid, err := tcred.UseUID(auth.UID(cred.UID))
+ if err != nil {
+ return nil, err
+ }
+ kgid, err := tcred.UseGID(auth.GID(cred.GID))
+ if err != nil {
+ return nil, err
+ }
+ if kernel.ThreadID(cred.PID) != t.ThreadGroup().ID() && !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.PIDNamespace().UserNamespace()) {
+ return nil, syserror.EPERM
+ }
+ return &scmCredentials{t, kuid, kgid}, nil
+}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (c *scmCredentials) Equals(oc transport.CredentialsControlMessage) bool {
+ if oc, _ := oc.(*scmCredentials); oc != nil && *c == *oc {
+ return true
+ }
+ return false
+}
+
+func putUint64(buf []byte, n uint64) []byte {
+ usermem.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n)
+ return buf[:len(buf)+8]
+}
+
+func putUint32(buf []byte, n uint32) []byte {
+ usermem.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n)
+ return buf[:len(buf)+4]
+}
+
+// putCmsg writes a control message header and as much data as will fit into
+// the unused capacity of a buffer.
+func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
+
+ // We can't write to space that doesn't exist, so if we are going to align
+ // the available space, we must align down.
+ //
+ // align must be >= 4 and each data int32 is 4 bytes. The length of the
+ // header is already aligned, so if we align to the width of the data there
+ // are two cases:
+ // 1. The aligned length is less than the length of the header. The
+ // unaligned length was also less than the length of the header, so we
+ // can't write anything.
+ // 2. The aligned length is greater than or equal to the length of the
+ // header. We can write the header plus zero or more bytes of data. We can't
+ // write a partial int32, so the length of the message will be
+ // min(aligned length, header + data).
+ if space < linux.SizeOfControlMessageHeader {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+
+ length := 4*len(data) + linux.SizeOfControlMessageHeader
+ if length > space {
+ length = space
+ }
+ buf = putUint64(buf, uint64(length))
+ buf = putUint32(buf, linux.SOL_SOCKET)
+ buf = putUint32(buf, msgType)
+ for _, d := range data {
+ if len(buf)+4 > cap(buf) {
+ flags |= linux.MSG_CTRUNC
+ break
+ }
+ buf = putUint32(buf, uint32(d))
+ }
+ return alignSlice(buf, align), flags
+}
+
+func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interface{}) []byte {
+ if cap(buf)-len(buf) < linux.SizeOfControlMessageHeader {
+ return buf
+ }
+ ob := buf
+
+ buf = putUint64(buf, uint64(linux.SizeOfControlMessageHeader))
+ buf = putUint32(buf, msgLevel)
+ buf = putUint32(buf, msgType)
+
+ hdrBuf := buf
+
+ buf = binary.Marshal(buf, usermem.ByteOrder, data)
+
+ // If the control message data brought us over capacity, omit it.
+ if cap(buf) != cap(ob) {
+ return hdrBuf
+ }
+
+ // Update control message length to include data.
+ putUint64(ob, uint64(len(buf)-len(ob)))
+
+ return alignSlice(buf, align)
+}
+
+// Credentials implements SCMCredentials.Credentials.
+func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ // "When a process's user and group IDs are passed over a UNIX domain
+ // socket to a process in a different user namespace (see the description
+ // of SCM_CREDENTIALS in unix(7)), they are translated into the
+ // corresponding values as per the receiving process's user and group ID
+ // mappings." - user_namespaces(7)
+ pid := t.PIDNamespace().IDOfTask(c.t)
+ uid := c.kuid.In(t.UserNamespace()).OrOverflow()
+ gid := c.kgid.In(t.UserNamespace()).OrOverflow()
+
+ return pid, uid, gid
+}
+
+// PackCredentials packs the credentials in the control message (or default
+// credentials if none) into a buffer.
+func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int) ([]byte, int) {
+ align := t.Arch().Width()
+
+ // Default credentials if none are available.
+ pid := kernel.ThreadID(0)
+ uid := auth.UID(auth.NobodyKUID)
+ gid := auth.GID(auth.NobodyKGID)
+
+ if creds != nil {
+ pid, uid, gid = creds.Credentials(t)
+ }
+ c := []int32{int32(pid), int32(uid), int32(gid)}
+ return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
+}
+
+// alignSlice extends a slice's length (up to the capacity) to align it.
+func alignSlice(buf []byte, align uint) []byte {
+ aligned := binary.AlignUp(len(buf), align)
+ if aligned > cap(buf) {
+ // Linux allows unaligned data if there isn't room for alignment.
+ // Since there isn't room for alignment, there isn't room for any
+ // additional messages either.
+ return buf
+ }
+ return buf[:aligned]
+}
+
+// PackTimestamp packs a SO_TIMESTAMP socket control message.
+func PackTimestamp(t *kernel.Task, timestamp int64, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_SOCKET,
+ linux.SO_TIMESTAMP,
+ t.Arch().Width(),
+ linux.NsecToTimeval(timestamp),
+ )
+}
+
+// PackInq packs a TCP_INQ socket control message.
+func PackInq(t *kernel.Task, inq int32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_TCP,
+ linux.TCP_INQ,
+ t.Arch().Width(),
+ inq,
+ )
+}
+
+// PackTOS packs an IP_TOS socket control message.
+func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_TOS,
+ t.Arch().Width(),
+ tos,
+ )
+}
+
+// PackTClass packs an IPV6_TCLASS socket control message.
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IPV6,
+ linux.IPV6_TCLASS,
+ t.Arch().Width(),
+ tClass,
+ )
+}
+
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
+// PackControlMessages packs control messages into the given buffer.
+//
+// We skip control messages specific to Unix domain sockets.
+//
+// Note that some control messages may be truncated if they do not fit under
+// the capacity of buf.
+func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte {
+ if cmsgs.IP.HasTimestamp {
+ buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf)
+ }
+
+ if cmsgs.IP.HasInq {
+ // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP.
+ buf = PackInq(t, cmsgs.IP.Inq, buf)
+ }
+
+ if cmsgs.IP.HasTOS {
+ buf = PackTOS(t, cmsgs.IP.TOS, buf)
+ }
+
+ if cmsgs.IP.HasTClass {
+ buf = PackTClass(t, cmsgs.IP.TClass, buf)
+ }
+
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
+ return buf
+}
+
+// cmsgSpace is equivalent to CMSG_SPACE in Linux.
+func cmsgSpace(t *kernel.Task, dataLen int) int {
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
+}
+
+// CmsgsSpace returns the number of bytes needed to fit the control messages
+// represented in cmsgs.
+func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
+ space := 0
+
+ if cmsgs.IP.HasTimestamp {
+ space += cmsgSpace(t, linux.SizeOfTimeval)
+ }
+
+ if cmsgs.IP.HasInq {
+ space += cmsgSpace(t, linux.SizeOfControlMessageInq)
+ }
+
+ if cmsgs.IP.HasTOS {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTOS)
+ }
+
+ if cmsgs.IP.HasTClass {
+ space += cmsgSpace(t, linux.SizeOfControlMessageTClass)
+ }
+
+ return space
+}
+
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
+// Parse parses a raw socket control message into portable objects.
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
+ var (
+ cmsgs socket.ControlMessages
+ fds linux.ControlMessageRights
+ )
+
+ for i := 0; i < len(buf); {
+ if i+linux.SizeOfControlMessageHeader > len(buf) {
+ return cmsgs, syserror.EINVAL
+ }
+
+ var h linux.ControlMessageHeader
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+
+ if h.Length < uint64(linux.SizeOfControlMessageHeader) {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ if h.Length > uint64(len(buf)-i) {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ i += linux.SizeOfControlMessageHeader
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+
+ // The use of t.Arch().Width() is analogous to Linux's use of
+ // sizeof(long) in CMSG_ALIGN.
+ width := t.Arch().Width()
+
+ switch h.Level {
+ case linux.SOL_SOCKET:
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += binary.AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ var creds linux.ControlMessageCredentials
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ scmCreds, err := NewSCMCredentials(t, creds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Credentials = scmCreds
+ i += binary.AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ case linux.SOL_IP:
+ switch h.Type {
+ case linux.IP_TOS:
+ if length < linux.SizeOfControlMessageTOS {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTOS = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ case linux.SOL_IPV6:
+ switch h.Type {
+ case linux.IPV6_TCLASS:
+ if length < linux.SizeOfControlMessageTClass {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ cmsgs.IP.HasTClass = true
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
+ i += binary.AlignUp(length, width)
+
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ default:
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+ }
+
+ if cmsgs.Unix.Credentials == nil {
+ cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint)
+ }
+
+ if len(fds) > 0 {
+ if kernel.VFS2Enabled {
+ rights, err := NewSCMRightsVFS2(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
+ } else {
+ rights, err := NewSCMRights(t, fds)
+ if err != nil {
+ return socket.ControlMessages{}, err
+ }
+ cmsgs.Unix.Rights = rights
+ }
+ }
+
+ return cmsgs, nil
+}
+
+func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
+ if t == nil || socketOrEndpoint == nil {
+ return nil
+ }
+ if cr, ok := socketOrEndpoint.(transport.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) {
+ return MakeCreds(t)
+ }
+ return nil
+}
+
+// MakeCreds creates default SCMCredentials.
+func MakeCreds(t *kernel.Task) SCMCredentials {
+ if t == nil {
+ return nil
+ }
+ tcred := t.Credentials()
+ return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
+}
+
+// LINT.IfChange
+
+// New creates default control messages if needed.
+func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages {
+ return transport.ControlMessages{
+ Credentials: makeCreds(t, socketOrEndpoint),
+ Rights: rights,
+ }
+}
+
+// LINT.ThenChange(./control_vfs2.go)
diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go
new file mode 100644
index 000000000..fd08179be
--- /dev/null
+++ b/pkg/sentry/socket/control/control_vfs2.go
@@ -0,0 +1,131 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserror"
+)
+
+// SCMRightsVFS2 represents a SCM_RIGHTS socket control message.
+type SCMRightsVFS2 interface {
+ transport.RightsControlMessage
+
+ // Files returns up to max RightsFiles.
+ //
+ // Returned files are consumed and ownership is transferred to the caller.
+ // Subsequent calls to Files will return the next files.
+ Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool)
+}
+
+// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
+// maintained for each vfs.FileDescription and is release either when an FD is created or
+// when the Release method is called.
+type RightsFilesVFS2 []*vfs.FileDescription
+
+// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message
+// representation using local sentry FDs.
+func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) {
+ files := make(RightsFilesVFS2, 0, len(fds))
+ for _, fd := range fds {
+ file := t.GetFileVFS2(fd)
+ if file == nil {
+ files.Release()
+ return nil, syserror.EBADF
+ }
+ files = append(files, file)
+ }
+ return &files, nil
+}
+
+// Files implements SCMRights.Files.
+func (fs *RightsFilesVFS2) Files(ctx context.Context, max int) (RightsFilesVFS2, bool) {
+ n := max
+ var trunc bool
+ if l := len(*fs); n > l {
+ n = l
+ } else if n < l {
+ trunc = true
+ }
+ rf := (*fs)[:n]
+ *fs = (*fs)[n:]
+ return rf, trunc
+}
+
+// Clone implements transport.RightsControlMessage.Clone.
+func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage {
+ nfs := append(RightsFilesVFS2(nil), *fs...)
+ for _, nf := range nfs {
+ nf.IncRef()
+ }
+ return &nfs
+}
+
+// Release implements transport.RightsControlMessage.Release.
+func (fs *RightsFilesVFS2) Release() {
+ for _, f := range *fs {
+ f.DecRef()
+ }
+ *fs = nil
+}
+
+// rightsFDsVFS2 gets up to the specified maximum number of FDs.
+func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int) ([]int32, bool) {
+ files, trunc := rights.Files(t, max)
+ fds := make([]int32, 0, len(files))
+ for i := 0; i < max && len(files) > 0; i++ {
+ fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{
+ CloseOnExec: cloexec,
+ })
+ files[0].DecRef()
+ files = files[1:]
+ if err != nil {
+ t.Warningf("Error inserting FD: %v", err)
+ // This is what Linux does.
+ break
+ }
+
+ fds = append(fds, int32(fd))
+ }
+ return fds, trunc
+}
+
+// PackRightsVFS2 packs as many FDs as will fit into the unused capacity of buf.
+func PackRightsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, buf []byte, flags int) ([]byte, int) {
+ maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
+ // Linux does not return any FDs if none fit.
+ if maxFDs <= 0 {
+ flags |= linux.MSG_CTRUNC
+ return buf, flags
+ }
+ fds, trunc := rightsFDsVFS2(t, rights, cloexec, maxFDs)
+ if trunc {
+ flags |= linux.MSG_CTRUNC
+ }
+ align := t.Arch().Width()
+ return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds)
+}
+
+// NewVFS2 creates default control messages if needed.
+func NewVFS2(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRightsVFS2) transport.ControlMessages {
+ return transport.ControlMessages{
+ Credentials: makeCreds(t, socketOrEndpoint),
+ Rights: rights,
+ }
+}
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
new file mode 100644
index 000000000..ff81ea6e6
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -0,0 +1,45 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "hostinet",
+ srcs = [
+ "device.go",
+ "hostinet.go",
+ "save_restore.go",
+ "socket.go",
+ "socket_unsafe.go",
+ "socket_vfs2.go",
+ "sockopt_impl.go",
+ "stack.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/fdnotifier",
+ "//pkg/log",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/hostfd",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip/stack",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/socket/hostinet/device.go b/pkg/sentry/socket/hostinet/device.go
new file mode 100644
index 000000000..27049d65f
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/device.go
@@ -0,0 +1,19 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+var socketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/hostinet/hostinet.go b/pkg/sentry/socket/hostinet/hostinet.go
new file mode 100644
index 000000000..0d6f51d2b
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/hostinet.go
@@ -0,0 +1,17 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostinet implements AF_INET and AF_INET6 sockets using the host's
+// network stack.
+package hostinet
diff --git a/pkg/sentry/socket/hostinet/save_restore.go b/pkg/sentry/socket/hostinet/save_restore.go
new file mode 100644
index 000000000..1dec33897
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/save_restore.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+// beforeSave is invoked by stateify.
+func (*socketOperations) beforeSave() {
+ panic("host.socketOperations is not savable")
+}
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
new file mode 100644
index 000000000..a92aed2c9
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -0,0 +1,713 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "syscall"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const (
+ sizeofInt32 = 4
+
+ // sizeofSockaddr is the size in bytes of the largest sockaddr type
+ // supported by this package.
+ sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
+
+ // maxControlLen is the maximum size of a control message buffer used in a
+ // recvmsg or sendmsg syscall.
+ maxControlLen = 1024
+)
+
+// LINT.IfChange
+
+// socketOperations implements fs.FileOperations and socket.Socket for a socket
+// implemented using a host socket.
+type socketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socket.SendReceiveTimeout
+
+ family int // Read-only.
+ stype linux.SockType // Read-only.
+ protocol int // Read-only.
+ queue waiter.Queue
+
+ // fd is the host socket fd. It must have O_NONBLOCK, so that operations
+ // will return EWOULDBLOCK instead of blocking on the host. This allows us to
+ // handle blocking behavior independently in the sentry.
+ fd int
+}
+
+var _ = socket.Socket(&socketOperations{})
+
+func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) {
+ s := &socketOperations{
+ socketOpsCommon: socketOpsCommon{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ },
+ }
+ if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ dirent := socket.NewDirent(ctx, socketDevice)
+ defer dirent.DecRef()
+ return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ fdnotifier.RemoveFD(int32(s.fd))
+ syscall.Close(s.fd)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
+ s.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(s.fd))
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return ioctl(ctx, s.fd, io, args)
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ if dsts.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Read(s.fd, dsts.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return readv(s.fd, safemem.IovecsFromBlockSeq(dsts))
+ }))
+ return int64(n), err
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ if srcs.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Write(s.fd, srcs.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return writev(s.fd, safemem.IovecsFromBlockSeq(srcs))
+ }))
+ return int64(n), err
+}
+
+// Connect implements socket.Socket.Connect.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ if len(sockaddr) > sizeofSockaddr {
+ sockaddr = sockaddr[:sizeofSockaddr]
+ }
+
+ _, _, errno := syscall.Syscall(syscall.SYS_CONNECT, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
+
+ if errno == 0 {
+ return nil
+ }
+ if errno != syscall.EINPROGRESS || !blocking {
+ return syserr.FromError(translateIOSyscallError(errno))
+ }
+
+ // "EINPROGRESS: The socket is nonblocking and the connection cannot be
+ // completed immediately. It is possible to select(2) or poll(2) for
+ // completion by selecting the socket for writing. After select(2)
+ // indicates writability, use getsockopt(2) to read the SO_ERROR option at
+ // level SOL-SOCKET to determine whether connect() completed successfully
+ // (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual error
+ // codes listed here, explaining the reason for the failure)." - connect(2)
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ if s.Readiness(waiter.EventOut)&waiter.EventOut == 0 {
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+ }
+ val, err := syscall.GetsockoptInt(s.fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ if val != 0 {
+ return syserr.FromError(syscall.Errno(uintptr(val)))
+ }
+ return nil
+}
+
+// Accept implements socket.Socket.Accept.
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ var peerAddr linux.SockAddr
+ var peerAddrBuf []byte
+ var peerAddrlen uint32
+ var peerAddrPtr *byte
+ var peerAddrlenPtr *uint32
+ if peerRequested {
+ peerAddrBuf = make([]byte, sizeofSockaddr)
+ peerAddrlen = uint32(len(peerAddrBuf))
+ peerAddrPtr = &peerAddrBuf[0]
+ peerAddrlenPtr = &peerAddrlen
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOpsCommon requires it.
+ fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
+ if blocking {
+ var ch chan struct{}
+ for syscallErr == syserror.ErrWouldBlock {
+ if ch != nil {
+ if syscallErr = t.Block(ch); syscallErr != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ fd, syscallErr = accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC)
+ }
+ }
+
+ if peerRequested {
+ peerAddr = socket.UnmarshalSockAddr(s.family, peerAddrBuf[:peerAddrlen])
+ }
+ if syscallErr != nil {
+ return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
+ }
+
+ var (
+ kfd int32
+ kerr error
+ )
+ if kernel.VFS2Enabled {
+ f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&syscall.SOCK_NONBLOCK))
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef()
+
+ kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocketVFS2(f)
+ } else {
+ f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ if err != nil {
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ defer f.DecRef()
+
+ kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ })
+ t.Kernel().RecordSocket(f)
+ }
+
+ return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
+}
+
+// Bind implements socket.Socket.Bind.
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) > sizeofSockaddr {
+ sockaddr = sockaddr[:sizeofSockaddr]
+ }
+
+ _, _, errno := syscall.Syscall(syscall.SYS_BIND, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
+ if errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+}
+
+// Listen implements socket.Socket.Listen.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.FromError(syscall.Listen(s.fd, backlog))
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ switch how {
+ case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR:
+ return syserr.FromError(syscall.Shutdown(s.fd, how))
+ default:
+ return syserr.ErrInvalidArgument
+ }
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ if outLen < 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only allow known and safe options.
+ optlen := getSockOptLen(t, level, name)
+ switch level {
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_IPV6:
+ switch name {
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ optlen = sizeofInt32
+ case linux.SO_LINGER:
+ optlen = syscall.SizeofLinger
+ }
+ case linux.SOL_TCP:
+ switch name {
+ case linux.TCP_NODELAY:
+ optlen = sizeofInt32
+ case linux.TCP_INFO:
+ optlen = int(linux.SizeOfTCPInfo)
+ }
+ }
+
+ if optlen == 0 {
+ return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
+ }
+ if outLen < optlen {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ opt, err := getsockopt(s.fd, level, name, optlen)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return opt, nil
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ // Only allow known and safe options.
+ optlen := setSockOptLen(t, level, name)
+ switch level {
+ case linux.SOL_IP:
+ switch name {
+ case linux.IP_TOS, linux.IP_RECVTOS:
+ optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
+ }
+ case linux.SOL_IPV6:
+ switch name {
+ case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR:
+ optlen = sizeofInt32
+ }
+ case linux.SOL_TCP:
+ switch name {
+ case linux.TCP_NODELAY:
+ optlen = sizeofInt32
+ }
+ }
+
+ if optlen == 0 {
+ // Pretend to accept socket options we don't understand. This seems
+ // dangerous, but it's what netstack does...
+ return nil
+ }
+ if len(opt) < optlen {
+ return syserr.ErrInvalidArgument
+ }
+ opt = opt[:optlen]
+
+ _, _, errno := syscall.Syscall6(syscall.SYS_SETSOCKOPT, uintptr(s.fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(len(opt)), 0)
+ if errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ // Only allow known and safe flags.
+ //
+ // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary
+ // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the
+ // Socket interface's dependence on netstack.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument
+ }
+
+ var senderAddr linux.SockAddr
+ var senderAddrBuf []byte
+ if senderRequested {
+ senderAddrBuf = make([]byte, sizeofSockaddr)
+ }
+
+ var controlBuf []byte
+ var msgFlags int
+
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ iovs := safemem.IovecsFromBlockSeq(dsts)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(senderAddrBuf) != 0 {
+ msg.Name = &senderAddrBuf[0]
+ msg.Namelen = uint32(len(senderAddrBuf))
+ }
+ if controlLen > 0 {
+ if controlLen > maxControlLen {
+ controlLen = maxControlLen
+ }
+ controlBuf = make([]byte, controlLen)
+ msg.Control = &controlBuf[0]
+ msg.Controllen = controlLen
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0, err
+ }
+ senderAddrBuf = senderAddrBuf[:msg.Namelen]
+ msgFlags = int(msg.Flags)
+ controlLen = uint64(msg.Controllen)
+ return n, nil
+ })
+
+ var ch chan struct{}
+ n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyOutFrom: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ if ch != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
+ }
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ if senderRequested {
+ senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf)
+ }
+
+ unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen])
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ controlMessages := socket.ControlMessages{}
+ for _, unixCmsg := range unixControlMessages {
+ switch unixCmsg.Header.Level {
+ case syscall.SOL_IP:
+ switch unixCmsg.Header.Type {
+ case syscall.IP_TOS:
+ controlMessages.IP.HasTOS = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
+ }
+
+ case syscall.SOL_IPV6:
+ switch unixCmsg.Header.Type {
+ case syscall.IPV6_TCLASS:
+ controlMessages.IP.HasTClass = true
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass)
+ }
+ }
+ }
+
+ return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Only allow known and safe flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ space := uint64(control.CmsgsSpace(t, controlMessages))
+ if space > maxControlLen {
+ space = maxControlLen
+ }
+ controlBuf := make([]byte, 0, space)
+ // PackControlMessages will append up to space bytes to controlBuf.
+ controlBuf = control.PackControlMessages(t, controlMessages, controlBuf)
+
+ sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() && len(controlBuf) == 0 {
+ return 0, nil
+ }
+
+ // We always do a non-blocking send*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ if srcs.NumBlocks() == 1 && len(controlBuf) == 0 {
+ // Skip allocating []syscall.Iovec.
+ src := srcs.Head()
+ n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
+
+ iovs := safemem.IovecsFromBlockSeq(srcs)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(to) != 0 {
+ msg.Name = &to[0]
+ msg.Namelen = uint32(len(to))
+ }
+ if len(controlBuf) != 0 {
+ msg.Control = &controlBuf[0]
+ msg.Controllen = uint64(len(controlBuf))
+ }
+ return sendmsg(s.fd, &msg, sysflags)
+ })
+
+ var ch chan struct{}
+ n, err := src.CopyInTo(t, sendmsgFromBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ if ch != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ }
+ n, err = src.CopyInTo(t, sendmsgFromBlocks)
+ }
+ }
+
+ return int(n), syserr.FromError(err)
+}
+
+func translateIOSyscallError(err error) error {
+ if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK {
+ return syserror.ErrWouldBlock
+ }
+ return err
+}
+
+// State implements socket.Socket.State.
+func (s *socketOpsCommon) State() uint32 {
+ info := linux.TCPInfo{}
+ buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo)
+ if err != nil {
+ if err != syscall.ENOPROTOOPT {
+ log.Warningf("Failed to get TCP socket info from %+v: %v", s, err)
+ }
+ // For non-TCP sockets, silently ignore the failure.
+ return 0
+ }
+ if len(buf) != linux.SizeOfTCPInfo {
+ // Unmarshal below will panic if getsockopt returns a buffer of
+ // unexpected size.
+ log.Warningf("Failed to get TCP socket info from %+v: getsockopt(2) returned %d bytes, expecting %d bytes.", s, len(buf), linux.SizeOfTCPInfo)
+ return 0
+ }
+
+ binary.Unmarshal(buf, usermem.ByteOrder, &info)
+ return uint32(info.State)
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.stype, s.protocol
+}
+
+type socketProvider struct {
+ family int
+}
+
+// Socket implements socket.Provider.Socket.
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := stypeflags & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return newSocketFile(t, p.family, stype, protocol, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
+}
+
+// Pair implements socket.Provider.Pair.
+func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
+
+// LINT.ThenChange(./socket_vfs2.go)
+
+func init() {
+ for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} {
+ socket.RegisterProvider(family, &socketProvider{family})
+ socket.RegisterProviderVFS2(family, &socketProviderVFS2{})
+ }
+}
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
new file mode 100644
index 000000000..3f420c2ec
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -0,0 +1,139 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+func firstBytePtr(bs []byte) unsafe.Pointer {
+ if bs == nil {
+ return nil
+ }
+ return unsafe.Pointer(&bs[0])
+}
+
+// Preconditions: len(dsts) != 0.
+func readv(fd int, dsts []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&dsts[0])), uintptr(len(dsts)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+// Preconditions: len(srcs) != 0.
+func writev(fd int, srcs []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&srcs[0])), uintptr(len(srcs)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := uintptr(args[1].Int()); cmd {
+ case syscall.TIOCINQ, syscall.TIOCOUTQ:
+ var val int32
+ if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], uint32(val))
+ _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+}
+
+func accept4(fd int, addr *byte, addrlen *uint32, flags int) (int, error) {
+ afd, _, errno := syscall.Syscall6(syscall.SYS_ACCEPT4, uintptr(fd), uintptr(unsafe.Pointer(addr)), uintptr(unsafe.Pointer(addrlen)), uintptr(flags), 0, 0)
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return int(afd), nil
+}
+
+func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
+ opt := make([]byte, optlen)
+ optlen32 := int32(len(opt))
+ _, _, errno := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(unsafe.Pointer(&optlen32)), 0)
+ if errno != 0 {
+ return nil, errno
+ }
+ return opt[:optlen32], nil
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return socket.UnmarshalSockAddr(s.family, addr), addrlen, nil
+}
+
+func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) {
+ fromLen := uint32(len(*from))
+ n, _, errno := syscall.Syscall6(syscall.SYS_RECVFROM, uintptr(fd), uintptr(firstBytePtr(dst)), uintptr(len(dst)), uintptr(flags), uintptr(firstBytePtr(*from)), uintptr(unsafe.Pointer(&fromLen)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ *from = (*from)[:fromLen]
+ return uint64(n), nil
+}
+
+func recvmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
+
+func sendmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+}
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
new file mode 100644
index 000000000..8f192c62f
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -0,0 +1,202 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fdnotifier"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/hostfd"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+type socketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.LockFD
+
+ // We store metadata for hostinet sockets internally. Technically, we should
+ // access metadata (e.g. through stat, chmod) on the host for correctness,
+ // but this is not very useful for inet socket fds, which do not belong to a
+ // concrete file anyway.
+ vfs.DentryMetadataFileDescriptionImpl
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&socketVFS2{})
+
+func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) {
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ s := &socketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ family: family,
+ stype: stype,
+ protocol: protocol,
+ fd: fd,
+ },
+ }
+ s.LockFD.Init(&vfs.FileLocks{})
+ if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ vfsfd := &s.vfsfd
+ if err := vfsfd.Init(s, linux.O_RDWR|(flags&linux.O_NONBLOCK), mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return ioctl(ctx, s.fd, uio, args)
+}
+
+// Allocate implements vfs.FileDescriptionImpl.Allocate.
+func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error {
+ return syserror.ENODEV
+}
+
+// PRead implements vfs.FileDescriptionImpl.PRead.
+func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ reader := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
+ n, err := dst.CopyOutFrom(ctx, reader)
+ hostfd.PutReadWriterAt(reader)
+ return int64(n), err
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) PWrite(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ writer := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags)
+ n, err := src.CopyInTo(ctx, writer)
+ hostfd.PutReadWriterAt(writer)
+ return int64(n), err
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
+type socketProviderVFS2 struct {
+ family int
+}
+
+// Socket implements socket.ProviderVFS2.Socket.
+func (p *socketProviderVFS2) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := stypeflags & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return newVFS2Socket(t, p.family, stype, protocol, fd, uint32(stypeflags&syscall.SOCK_NONBLOCK))
+}
+
+// Pair implements socket.Provider.Pair.
+func (p *socketProviderVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
diff --git a/pkg/sentry/socket/hostinet/sockopt_impl.go b/pkg/sentry/socket/hostinet/sockopt_impl.go
new file mode 100644
index 000000000..8a783712e
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/sockopt_impl.go
@@ -0,0 +1,27 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+func getSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
+
+func setSockOptLen(t *kernel.Task, level, name int) int {
+ return 0 // No custom options.
+}
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
new file mode 100644
index 000000000..a48082631
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -0,0 +1,459 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "io"
+ "io/ioutil"
+ "os"
+ "reflect"
+ "strconv"
+ "strings"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+var defaultRecvBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 87380,
+ Max: 6291456,
+}
+
+var defaultSendBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 16384,
+ Max: 4194304,
+}
+
+// Stack implements inet.Stack for host sockets.
+type Stack struct {
+ // Stack is immutable.
+ interfaces map[int32]inet.Interface
+ interfaceAddrs map[int32][]inet.InterfaceAddr
+ routes []inet.Route
+ supportsIPv6 bool
+ tcpRecvBufSize inet.TCPBufferSize
+ tcpSendBufSize inet.TCPBufferSize
+ tcpSACKEnabled bool
+ netDevFile *os.File
+ netSNMPFile *os.File
+}
+
+// NewStack returns an empty Stack containing no configuration.
+func NewStack() *Stack {
+ return &Stack{
+ interfaces: make(map[int32]inet.Interface),
+ interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
+ }
+}
+
+// Configure sets up the stack using the current state of the host network.
+func (s *Stack) Configure() error {
+ if err := addHostInterfaces(s); err != nil {
+ return err
+ }
+
+ if err := addHostRoutes(s); err != nil {
+ return err
+ }
+
+ if _, err := os.Stat("/proc/net/if_inet6"); err == nil {
+ s.supportsIPv6 = true
+ }
+
+ s.tcpRecvBufSize = defaultRecvBufSize
+ if tcpRMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_rmem"); err == nil {
+ s.tcpRecvBufSize = tcpRMem
+ } else {
+ log.Warningf("Failed to read TCP receive buffer size, using default values")
+ }
+
+ s.tcpSendBufSize = defaultSendBufSize
+ if tcpWMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_wmem"); err == nil {
+ s.tcpSendBufSize = tcpWMem
+ } else {
+ log.Warningf("Failed to read TCP send buffer size, using default values")
+ }
+
+ // SACK is important for performance and even compatibility, assume it's
+ // enabled if we can't find the actual value.
+ s.tcpSACKEnabled = true
+ if sack, err := ioutil.ReadFile("/proc/sys/net/ipv4/tcp_sack"); err == nil {
+ s.tcpSACKEnabled = strings.TrimSpace(string(sack)) != "0"
+ } else {
+ log.Warningf("Failed to read if TCP SACK if enabled, setting to true")
+ }
+
+ if f, err := os.Open("/proc/net/dev"); err != nil {
+ log.Warningf("Failed to open /proc/net/dev: %v", err)
+ } else {
+ s.netDevFile = f
+ }
+
+ if f, err := os.Open("/proc/net/snmp"); err != nil {
+ log.Warningf("Failed to open /proc/net/snmp: %v", err)
+ } else {
+ s.netSNMPFile = f
+ }
+
+ return nil
+}
+
+// ExtractHostInterfaces will populate an interface map and
+// interfaceAddrs map with the results of the equivalent
+// netlink messages.
+func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.NetlinkMessage, interfaces map[int32]inet.Interface, interfaceAddrs map[int32][]inet.InterfaceAddr) error {
+ for _, link := range links {
+ if link.Header.Type != syscall.RTM_NEWLINK {
+ continue
+ }
+ if len(link.Data) < syscall.SizeofIfInfomsg {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), syscall.SizeofIfInfomsg)
+ }
+ var ifinfo syscall.IfInfomsg
+ binary.Unmarshal(link.Data[:syscall.SizeofIfInfomsg], usermem.ByteOrder, &ifinfo)
+ inetIF := inet.Interface{
+ DeviceType: ifinfo.Type,
+ Flags: ifinfo.Flags,
+ }
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct ifinfomsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&link)
+ if err != nil {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFLA_ADDRESS:
+ inetIF.Addr = attr.Value
+ case syscall.IFLA_IFNAME:
+ inetIF.Name = string(attr.Value[:len(attr.Value)-1])
+ }
+ }
+ interfaces[ifinfo.Index] = inetIF
+ }
+
+ for _, addr := range addrs {
+ if addr.Header.Type != syscall.RTM_NEWADDR {
+ continue
+ }
+ if len(addr.Data) < syscall.SizeofIfAddrmsg {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), syscall.SizeofIfAddrmsg)
+ }
+ var ifaddr syscall.IfAddrmsg
+ binary.Unmarshal(addr.Data[:syscall.SizeofIfAddrmsg], usermem.ByteOrder, &ifaddr)
+ inetAddr := inet.InterfaceAddr{
+ Family: ifaddr.Family,
+ PrefixLen: ifaddr.Prefixlen,
+ Flags: ifaddr.Flags,
+ }
+ attrs, err := syscall.ParseNetlinkRouteAttr(&addr)
+ if err != nil {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFA_ADDRESS:
+ inetAddr.Addr = attr.Value
+ }
+ }
+ interfaceAddrs[int32(ifaddr.Index)] = append(interfaceAddrs[int32(ifaddr.Index)], inetAddr)
+ }
+
+ return nil
+}
+
+// ExtractHostRoutes populates the given routes slice with the data from the
+// host route table.
+func ExtractHostRoutes(routeMsgs []syscall.NetlinkMessage) ([]inet.Route, error) {
+ var routes []inet.Route
+ for _, routeMsg := range routeMsgs {
+ if routeMsg.Header.Type != syscall.RTM_NEWROUTE {
+ continue
+ }
+
+ var ifRoute syscall.RtMsg
+ binary.Unmarshal(routeMsg.Data[:syscall.SizeofRtMsg], usermem.ByteOrder, &ifRoute)
+ inetRoute := inet.Route{
+ Family: ifRoute.Family,
+ DstLen: ifRoute.Dst_len,
+ SrcLen: ifRoute.Src_len,
+ TOS: ifRoute.Tos,
+ Table: ifRoute.Table,
+ Protocol: ifRoute.Protocol,
+ Scope: ifRoute.Scope,
+ Type: ifRoute.Type,
+ Flags: ifRoute.Flags,
+ }
+
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct rtmsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&routeMsg)
+ if err != nil {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid rtattrs: %v", err)
+ }
+
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.RTA_DST:
+ inetRoute.DstAddr = attr.Value
+ case syscall.RTA_SRC:
+ inetRoute.SrcAddr = attr.Value
+ case syscall.RTA_GATEWAY:
+ inetRoute.GatewayAddr = attr.Value
+ case syscall.RTA_OIF:
+ expected := int(binary.Size(inetRoute.OutputInterface))
+ if len(attr.Value) != expected {
+ return nil, fmt.Errorf("RTM_GETROUTE returned RTM_NEWROUTE message with invalid attribute data length (%d bytes, expected %d bytes)", len(attr.Value), expected)
+ }
+ binary.Unmarshal(attr.Value, usermem.ByteOrder, &inetRoute.OutputInterface)
+ }
+ }
+
+ routes = append(routes, inetRoute)
+ }
+
+ return routes, nil
+}
+
+func addHostInterfaces(s *Stack) error {
+ links, err := doNetlinkRouteRequest(syscall.RTM_GETLINK)
+ if err != nil {
+ return fmt.Errorf("RTM_GETLINK failed: %v", err)
+ }
+
+ addrs, err := doNetlinkRouteRequest(syscall.RTM_GETADDR)
+ if err != nil {
+ return fmt.Errorf("RTM_GETADDR failed: %v", err)
+ }
+
+ return ExtractHostInterfaces(links, addrs, s.interfaces, s.interfaceAddrs)
+}
+
+func addHostRoutes(s *Stack) error {
+ routes, err := doNetlinkRouteRequest(syscall.RTM_GETROUTE)
+ if err != nil {
+ return fmt.Errorf("RTM_GETROUTE failed: %v", err)
+ }
+
+ s.routes, err = ExtractHostRoutes(routes)
+ if err != nil {
+ return err
+ }
+
+ return nil
+}
+
+func doNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
+ data, err := syscall.NetlinkRIB(req, syscall.AF_UNSPEC)
+ if err != nil {
+ return nil, err
+ }
+ return syscall.ParseNetlinkMessage(data)
+}
+
+func readTCPBufferSizeFile(filename string) (inet.TCPBufferSize, error) {
+ contents, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to read %s: %v", filename, err)
+ }
+ ioseq := usermem.BytesIOSequence(contents)
+ fields := make([]int32, 3)
+ if n, err := usermem.CopyInt32StringsInVec(context.Background(), ioseq.IO, ioseq.Addrs, fields, ioseq.Opts); n != ioseq.NumBytes() || err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to parse %s (%q): got %v after %d/%d bytes", filename, contents, err, n, ioseq.NumBytes())
+ }
+ return inet.TCPBufferSize{
+ Min: int(fields[0]),
+ Default: int(fields[1]),
+ Max: int(fields[2]),
+ }, nil
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ interfaces := make(map[int32]inet.Interface)
+ for k, v := range s.interfaces {
+ interfaces[k] = v
+ }
+ return interfaces
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ addrs := make(map[int32][]inet.InterfaceAddr)
+ for k, v := range s.interfaceAddrs {
+ addrs[k] = append([]inet.InterfaceAddr(nil), v...)
+ }
+ return addrs
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ return syserror.EACCES
+}
+
+// SupportsIPv6 implements inet.Stack.SupportsIPv6.
+func (s *Stack) SupportsIPv6() bool {
+ return s.supportsIPv6
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpRecvBufSize, nil
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpSendBufSize, nil
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ return s.tcpSACKEnabled, nil
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ return syserror.EACCES
+}
+
+// getLine reads one line from proc file, with specified prefix.
+// The last argument, withHeader, specifies if it contains line header.
+func getLine(f *os.File, prefix string, withHeader bool) string {
+ data := make([]byte, 4096)
+
+ if _, err := f.Seek(0, 0); err != nil {
+ return ""
+ }
+
+ if _, err := io.ReadFull(f, data); err != io.ErrUnexpectedEOF {
+ return ""
+ }
+
+ prefix = prefix + ":"
+ lines := strings.Split(string(data), "\n")
+ for _, l := range lines {
+ l = strings.TrimSpace(l)
+ if strings.HasPrefix(l, prefix) {
+ if withHeader {
+ withHeader = false
+ continue
+ }
+ return l
+ }
+ }
+ return ""
+}
+
+func toSlice(i interface{}) []uint64 {
+ v := reflect.Indirect(reflect.ValueOf(i))
+ return v.Slice(0, v.Len()).Interface().([]uint64)
+}
+
+// Statistics implements inet.Stack.Statistics.
+func (s *Stack) Statistics(stat interface{}, arg string) error {
+ var (
+ snmpTCP bool
+ rawLine string
+ sliceStat []uint64
+ )
+
+ switch stat.(type) {
+ case *inet.StatDev:
+ if s.netDevFile == nil {
+ return fmt.Errorf("/proc/net/dev is not opened for hostinet")
+ }
+ rawLine = getLine(s.netDevFile, arg, false /* with no header */)
+ case *inet.StatSNMPIP, *inet.StatSNMPICMP, *inet.StatSNMPICMPMSG, *inet.StatSNMPTCP, *inet.StatSNMPUDP, *inet.StatSNMPUDPLite:
+ if s.netSNMPFile == nil {
+ return fmt.Errorf("/proc/net/snmp is not opened for hostinet")
+ }
+ rawLine = getLine(s.netSNMPFile, arg, true)
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+
+ if rawLine == "" {
+ return fmt.Errorf("Failed to get raw line")
+ }
+
+ parts := strings.SplitN(rawLine, ":", 2)
+ if len(parts) != 2 {
+ return fmt.Errorf("Failed to get prefix from: %q", rawLine)
+ }
+
+ sliceStat = toSlice(stat)
+ fields := strings.Fields(strings.TrimSpace(parts[1]))
+ if len(fields) != len(sliceStat) {
+ return fmt.Errorf("Failed to parse fields: %q", rawLine)
+ }
+ if _, ok := stat.(*inet.StatSNMPTCP); ok {
+ snmpTCP = true
+ }
+ for i := 0; i < len(sliceStat); i++ {
+ var err error
+ if snmpTCP && i == 3 {
+ var tmp int64
+ // MaxConn field is signed, RFC 2012.
+ tmp, err = strconv.ParseInt(fields[i], 10, 64)
+ sliceStat[i] = uint64(tmp) // Convert back to int before use.
+ } else {
+ sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64)
+ }
+ if err != nil {
+ return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err)
+ }
+ }
+
+ return nil
+}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ return append([]inet.Route(nil), s.routes...)
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
new file mode 100644
index 000000000..721094bbf
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -0,0 +1,29 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "netfilter",
+ srcs = [
+ "extensions.go",
+ "netfilter.go",
+ "owner_matcher.go",
+ "targets.go",
+ "tcp_matcher.go",
+ "udp_matcher.go",
+ ],
+ # This target depends on netstack and should only be used by epsocket,
+ # which is allowed to depend on netstack.
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/log",
+ "//pkg/sentry/kernel",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/stack",
+ "//pkg/usermem",
+ ],
+)
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
new file mode 100644
index 000000000..0336a32d8
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -0,0 +1,95 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// TODO(gvisor.dev/issue/170): The following per-matcher params should be
+// supported:
+// - Table name
+// - Match size
+// - User size
+// - Hooks
+// - Proto
+// - Family
+
+// matchMaker knows how to (un)marshal the matcher named name().
+type matchMaker interface {
+ // name is the matcher name as stored in the xt_entry_match struct.
+ name() string
+
+ // marshal converts from an stack.Matcher to an ABI struct.
+ marshal(matcher stack.Matcher) []byte
+
+ // unmarshal converts from the ABI matcher struct to an
+ // stack.Matcher.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
+}
+
+// matchMakers maps the name of supported matchers to the matchMaker that
+// marshals and unmarshals it. It is immutable after package initialization.
+var matchMakers = map[string]matchMaker{}
+
+// registermatchMaker should be called by match extensions to register them
+// with the netfilter package.
+func registerMatchMaker(mm matchMaker) {
+ if _, ok := matchMakers[mm.name()]; ok {
+ panic(fmt.Sprintf("Multiple matches registered with name %q.", mm.name()))
+ }
+ matchMakers[mm.name()] = mm
+}
+
+func marshalMatcher(matcher stack.Matcher) []byte {
+ matchMaker, ok := matchMakers[matcher.Name()]
+ if !ok {
+ panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
+ }
+ return matchMaker.marshal(matcher)
+}
+
+// marshalEntryMatch creates a marshalled XTEntryMatch with the given name and
+// data appended at the end.
+func marshalEntryMatch(name string, data []byte) []byte {
+ nflog("marshaling matcher %q", name)
+
+ // We have to pad this struct size to a multiple of 8 bytes.
+ size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ matcher := linux.KernelXTEntryMatch{
+ XTEntryMatch: linux.XTEntryMatch{
+ MatchSize: uint16(size),
+ },
+ Data: data,
+ }
+ copy(matcher.Name[:], name)
+
+ buf := make([]byte, 0, size)
+ buf = binary.Marshal(buf, usermem.ByteOrder, matcher)
+ return append(buf, make([]byte, size-len(buf))...)
+}
+
+func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
+ matchMaker, ok := matchMakers[match.Name.String()]
+ if !ok {
+ return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
+ }
+ return matchMaker.unmarshal(buf, filter)
+}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
new file mode 100644
index 000000000..f7abe77d3
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -0,0 +1,761 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netfilter helps the sentry interact with netstack's netfilter
+// capabilities.
+package netfilter
+
+import (
+ "bytes"
+ "errors"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// errorTargetName is used to mark targets as error targets. Error targets
+// shouldn't be reached - an error has occurred if we fall through to one.
+const errorTargetName = "ERROR"
+
+// redirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const redirectTargetName = "REDIRECT"
+
+// enableLogging controls whether to log the (de)serialization of netfilter
+// structs between userspace and netstack. These logs are useful when
+// developing iptables, but can pollute sentry logs otherwise.
+const enableLogging = false
+
+// emptyFilter is for comparison with a rule's filters to determine whether it
+// is also empty. It is immutable.
+var emptyFilter = stack.IPHeaderFilter{
+ Dst: "\x00\x00\x00\x00",
+ DstMask: "\x00\x00\x00\x00",
+ Src: "\x00\x00\x00\x00",
+ SrcMask: "\x00\x00\x00\x00",
+}
+
+// nflog logs messages related to the writing and reading of iptables.
+func nflog(format string, args ...interface{}) {
+ if enableLogging && log.IsLogging(log.Debug) {
+ log.Debugf("netfilter: "+format, args...)
+ }
+}
+
+// GetInfo returns information about iptables.
+func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
+ // Read in the struct and table name.
+ var info linux.IPTGetinfo
+ if _, err := t.CopyIn(outPtr, &info); err != nil {
+ return linux.IPTGetinfo{}, syserr.FromError(err)
+ }
+
+ _, info, err := convertNetstackToBinary(stack, info.Name)
+ if err != nil {
+ nflog("couldn't convert iptables: %v", err)
+ return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
+ }
+
+ nflog("returning info: %+v", info)
+ return info, nil
+}
+
+// GetEntries returns netstack's iptables rules encoded for the iptables tool.
+func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+ // Read in the struct and table name.
+ var userEntries linux.IPTGetEntries
+ if _, err := t.CopyIn(outPtr, &userEntries); err != nil {
+ nflog("couldn't copy in entries %q", userEntries.Name)
+ return linux.KernelIPTGetEntries{}, syserr.FromError(err)
+ }
+
+ // Convert netstack's iptables rules to something that the iptables
+ // tool can understand.
+ entries, _, err := convertNetstackToBinary(stack, userEntries.Name)
+ if err != nil {
+ nflog("couldn't read entries: %v", err)
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+ if binary.Size(entries) > uintptr(outLen) {
+ nflog("insufficient GetEntries output size: %d", uintptr(outLen))
+ return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
+ }
+
+ return entries, nil
+}
+
+// convertNetstackToBinary converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a bit, reading some offsets,
+// jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
+ table, ok := stack.IPTables().GetTable(tablename.String())
+ if !ok {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
+ var entries linux.KernelIPTGetEntries
+ var info linux.IPTGetinfo
+ info.ValidHooks = table.ValidHooks()
+
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+ }
+ copy(info.Name[:], tablename[:])
+ copy(entries.Name[:], tablename[:])
+
+ for ruleIdx, rule := range table.Rules {
+ nflog("convert to binary: current offset: %d", entries.Size)
+
+ // Is this a chain entry point?
+ for hook, hookRuleIdx := range table.BuiltinChains {
+ if hookRuleIdx == ruleIdx {
+ nflog("convert to binary: found hook %d at offset %d", hook, entries.Size)
+ info.HookEntry[hook] = entries.Size
+ }
+ }
+ // Is this a chain underflow point?
+ for underflow, underflowRuleIdx := range table.Underflows {
+ if underflowRuleIdx == ruleIdx {
+ nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size)
+ info.Underflow[underflow] = entries.Size
+ }
+ }
+
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ IPTEntry: linux.IPTEntry{
+ IP: linux.IPTIP{
+ Protocol: uint16(rule.Filter.Protocol),
+ },
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+ copy(entry.IPTEntry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.IPTEntry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.IPTEntry.IP.Src[:], rule.Filter.Src)
+ copy(entry.IPTEntry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.IPTEntry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.IPTEntry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ if rule.Filter.DstInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ }
+ if rule.Filter.SrcInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ }
+ if rule.Filter.OutputInterfaceInvert {
+ entry.IPTEntry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ nflog("convert to binary: matcher serialized as: %v", serialized)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+ entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.NextOffset += uint16(len(serialized))
+
+ nflog("convert to binary: adding entry: %+v", entry)
+
+ entries.Size += uint32(entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ info.NumEntries++
+ }
+
+ nflog("convert to binary: finished with an marshalled size of %d", info.Size)
+ info.Size = entries.Size
+ return entries, info, nil
+}
+
+func marshalTarget(target stack.Target) []byte {
+ switch tg := target.(type) {
+ case stack.AcceptTarget:
+ return marshalStandardTarget(stack.RuleAccept)
+ case stack.DropTarget:
+ return marshalStandardTarget(stack.RuleDrop)
+ case stack.ErrorTarget:
+ return marshalErrorTarget(errorTargetName)
+ case stack.UserChainTarget:
+ return marshalErrorTarget(tg.Name)
+ case stack.ReturnTarget:
+ return marshalStandardTarget(stack.RuleReturn)
+ case stack.RedirectTarget:
+ return marshalRedirectTarget(tg)
+ case JumpTarget:
+ return marshalJumpTarget(tg)
+ default:
+ panic(fmt.Errorf("unknown target of type %T", target))
+ }
+}
+
+func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
+ nflog("convert to binary: marshalling standard target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ Verdict: translateFromStandardVerdict(verdict),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalErrorTarget(errorName string) []byte {
+ // This is an error target named error
+ target := linux.XTErrorTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTErrorTarget,
+ },
+ }
+ copy(target.Name[:], errorName)
+ copy(target.Target.Name[:], errorTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalRedirectTarget(rt stack.RedirectTarget) []byte {
+ // This is a redirect target named redirect
+ target := linux.XTRedirectTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTRedirectTarget,
+ },
+ }
+ copy(target.Target.Name[:], redirectTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ target.NfRange.RangeSize = 1
+ if rt.RangeProtoSpecified {
+ target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED
+ }
+ // Convert port from little endian to big endian.
+ port := make([]byte, 2)
+ binary.LittleEndian.PutUint16(port, rt.MinPort)
+ target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port)
+ binary.LittleEndian.PutUint16(port, rt.MaxPort)
+ target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+// translateFromStandardVerdict translates verdicts the same way as the iptables
+// tool.
+func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
+ switch verdict {
+ case stack.RuleAccept:
+ return -linux.NF_ACCEPT - 1
+ case stack.RuleDrop:
+ return -linux.NF_DROP - 1
+ case stack.RuleReturn:
+ return linux.NF_RETURN
+ default:
+ // TODO(gvisor.dev/issue/170): Support Jump.
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
+ }
+}
+
+// translateToStandardTarget translates from the value in a
+// linux.XTStandardTarget to an stack.Verdict.
+func translateToStandardTarget(val int32) (stack.Target, error) {
+ // TODO(gvisor.dev/issue/170): Support other verdicts.
+ switch val {
+ case -linux.NF_ACCEPT - 1:
+ return stack.AcceptTarget{}, nil
+ case -linux.NF_DROP - 1:
+ return stack.DropTarget{}, nil
+ case -linux.NF_QUEUE - 1:
+ return nil, errors.New("unsupported iptables verdict QUEUE")
+ case linux.NF_RETURN:
+ return stack.ReturnTarget{}, nil
+ default:
+ return nil, fmt.Errorf("unknown iptables verdict %d", val)
+ }
+}
+
+// SetEntries sets iptables rules for a single table. See
+// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
+func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
+ // Get the basic rules data (struct ipt_replace).
+ if len(optVal) < linux.SizeOfIPTReplace {
+ nflog("optVal has insufficient size for replace %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var replace linux.IPTReplace
+ replaceBuf := optVal[:linux.SizeOfIPTReplace]
+ optVal = optVal[linux.SizeOfIPTReplace:]
+ binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
+
+ // TODO(gvisor.dev/issue/170): Support other tables.
+ var table stack.Table
+ switch replace.Name.String() {
+ case stack.TablenameFilter:
+ table = stack.EmptyFilterTable()
+ case stack.TablenameNat:
+ table = stack.EmptyNatTable()
+ default:
+ nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
+ return syserr.ErrInvalidArgument
+ }
+
+ nflog("set entries: setting entries in table %q", replace.Name.String())
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ nflog("set entries: processing entry at offset %d", offset)
+
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIPTEntry {
+ nflog("optVal has insufficient size for entry %d", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ var entry linux.IPTEntry
+ buf := optVal[:linux.SizeOfIPTEntry]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ initialOptValLen := len(optVal)
+ optVal = optVal[linux.SizeOfIPTEntry:]
+
+ if entry.TargetOffset < linux.SizeOfIPTEntry {
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support more IPTIP
+ // filtering fields.
+ filter, err := filterFromIPTIP(entry.IP)
+ if err != nil {
+ nflog("bad iptip: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Matchers and targets can specify
+ // that they only work for certain protocols, hooks, tables.
+ // Get matchers.
+ matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
+ if len(optVal) < int(matchersSize) {
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ if err != nil {
+ nflog("failed to parse matchers: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+ optVal = optVal[matchersSize:]
+
+ // Get the target of the rule.
+ targetSize := entry.NextOffset - entry.TargetOffset
+ if len(optVal) < int(targetSize) {
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ target, err := parseTarget(filter, optVal[:targetSize])
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return syserr.ErrInvalidArgument
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, stack.Rule{
+ Filter: filter,
+ Target: target,
+ Matchers: matchers,
+ })
+ offsets[offset] = int(entryIdx)
+ offset += uint32(entry.NextOffset)
+
+ if initialOptValLen-len(optVal) != int(entry.NextOffset) {
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
+ }
+ }
+
+ // Go through the list of supported hooks for this table and, for each
+ // one, set the rule it corresponds to.
+ for hook, _ := range replace.HookEntry {
+ if table.ValidHooks()&(1<<hook) != 0 {
+ hk := hookFromLinux(hook)
+ for offset, ruleIdx := range offsets {
+ if offset == replace.HookEntry[hook] {
+ table.BuiltinChains[hk] = ruleIdx
+ }
+ if offset == replace.Underflow[hook] {
+ if !validUnderflow(table.Rules[ruleIdx]) {
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
+ return syserr.ErrInvalidArgument
+ }
+ table.Underflows[hk] = ruleIdx
+ }
+ }
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset {
+ nflog("hook %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset {
+ nflog("underflow %v is unset.", hk)
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+
+ // Add the user chains.
+ for ruleIdx, rule := range table.Rules {
+ target, ok := rule.Target.(stack.UserChainTarget)
+ if !ok {
+ continue
+ }
+
+ // We found a user chain. Before inserting it into the table,
+ // check that:
+ // - There's some other rule after it.
+ // - There are no matchers.
+ if ruleIdx == len(table.Rules)-1 {
+ nflog("user chain must have a rule or default policy")
+ return syserr.ErrInvalidArgument
+ }
+ if len(table.Rules[ruleIdx].Matchers) != 0 {
+ nflog("user chain's first node must have no matchers")
+ return syserr.ErrInvalidArgument
+ }
+ table.UserChains[target.Name] = ruleIdx + 1
+ }
+
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
+ if !ok {
+ continue
+ }
+
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
+ }
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
+
+ // TODO(gvisor.dev/issue/170): Support other chains.
+ // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
+ // make sure all other chains point to ACCEPT rules.
+ for hook, ruleIdx := range table.BuiltinChains {
+ if hook == stack.Forward || hook == stack.Postrouting {
+ if !isUnconditionalAccept(table.Rules[ruleIdx]) {
+ nflog("hook %d is unsupported.", hook)
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+
+ // TODO(gvisor.dev/issue/170): Check the following conditions:
+ // - There are no loops.
+ // - There are no chains without an unconditional final rule.
+ // - There are no chains without an unconditional underflow rule.
+
+ stk.IPTables().ReplaceTable(replace.Name.String(), table)
+
+ return nil
+}
+
+// parseMatchers parses 0 or more matchers from optVal. optVal should contain
+// only the matchers.
+func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
+ nflog("set entries: parsing matchers of size %d", len(optVal))
+ var matchers []stack.Matcher
+ for len(optVal) > 0 {
+ nflog("set entries: optVal has len %d", len(optVal))
+
+ // Get the XTEntryMatch.
+ if len(optVal) < linux.SizeOfXTEntryMatch {
+ return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal))
+ }
+ var match linux.XTEntryMatch
+ buf := optVal[:linux.SizeOfXTEntryMatch]
+ binary.Unmarshal(buf, usermem.ByteOrder, &match)
+ nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match)
+
+ // Check some invariants.
+ if match.MatchSize < linux.SizeOfXTEntryMatch {
+
+ return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch)
+ }
+ if len(optVal) < int(match.MatchSize) {
+ return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal))
+ }
+
+ // Parse the specific matcher.
+ matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize])
+ if err != nil {
+ return nil, fmt.Errorf("failed to create matcher: %v", err)
+ }
+ matchers = append(matchers, matcher)
+
+ // TODO(gvisor.dev/issue/170): Check the revision field.
+ optVal = optVal[match.MatchSize:]
+ }
+
+ if len(optVal) != 0 {
+ return nil, errors.New("optVal should be exhausted after parsing matchers")
+ }
+
+ return matchers, nil
+}
+
+// parseTarget parses a target from optVal. optVal should contain only the
+// target.
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
+ nflog("set entries: parsing target of size %d", len(optVal))
+ if len(optVal) < linux.SizeOfXTEntryTarget {
+ return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
+ }
+ var target linux.XTEntryTarget
+ buf := optVal[:linux.SizeOfXTEntryTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &target)
+ switch target.Name.String() {
+ case "":
+ // Standard target.
+ if len(optVal) != linux.SizeOfXTStandardTarget {
+ return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal))
+ }
+ var standardTarget linux.XTStandardTarget
+ buf = optVal[:linux.SizeOfXTStandardTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
+
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
+ }
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
+
+ case errorTargetName:
+ // Error target.
+ if len(optVal) != linux.SizeOfXTErrorTarget {
+ return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal))
+ }
+ var errorTarget linux.XTErrorTarget
+ buf = optVal[:linux.SizeOfXTErrorTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget)
+
+ // Error targets are used in 2 cases:
+ // * An actual error case. These rules have an error
+ // named errorTargetName. The last entry of the table
+ // is usually an error case to catch any packets that
+ // somehow fall through every rule.
+ // * To mark the start of a user defined chain. These
+ // rules have an error with the name of the chain.
+ switch name := errorTarget.Name.String(); name {
+ case errorTargetName:
+ nflog("set entries: error target")
+ return stack.ErrorTarget{}, nil
+ default:
+ // User defined chain.
+ nflog("set entries: user-defined target %q", name)
+ return stack.UserChainTarget{Name: name}, nil
+ }
+
+ case redirectTargetName:
+ // Redirect target.
+ if len(optVal) < linux.SizeOfXTRedirectTarget {
+ return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = optVal[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ var target stack.RedirectTarget
+ nfRange := redirectTarget.NfRange
+
+ // RangeSize should be 1.
+ if nfRange.RangeSize != 1 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+ target.RangeProtoSpecified = true
+
+ target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // Convert port from big endian to little endian.
+ port := make([]byte, 2)
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
+ target.MinPort = binary.LittleEndian.Uint16(port)
+
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
+ target.MaxPort = binary.LittleEndian.Uint16(port)
+ return target, nil
+ }
+
+ // Unknown target.
+ return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
+}
+
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
+ if containsUnsupportedFields(iptip) {
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
+ }
+ if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
+ }
+
+ n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterface)
+ }
+ ifname := string(iptip.OutputInterface[:n])
+
+ n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterfaceMask)
+ }
+ ifnameMask := string(iptip.OutputInterfaceMask[:n])
+
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
+ Src: tcpip.Address(iptip.Src[:]),
+ SrcMask: tcpip.Address(iptip.SrcMask[:]),
+ SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
+ OutputInterface: ifname,
+ OutputInterfaceMask: ifnameMask,
+ OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
+ }, nil
+}
+
+func containsUnsupportedFields(iptip linux.IPTIP) bool {
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - Src and SrcMask
+ // - The inverse destination IP check flag
+ // - OutputInterface, OutputInterfaceMask and its inverse.
+ var emptyInterface = [linux.IFNAMSIZ]byte{}
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
+ return iptip.InputInterface != emptyInterface ||
+ iptip.InputInterfaceMask != emptyInterface ||
+ iptip.Flags != 0 ||
+ iptip.InverseFlags&^inverseMask != 0
+}
+
+func validUnderflow(rule stack.Rule) bool {
+ if len(rule.Matchers) != 0 {
+ return false
+ }
+ if rule.Filter != emptyFilter {
+ return false
+ }
+ switch rule.Target.(type) {
+ case stack.AcceptTarget, stack.DropTarget:
+ return true
+ default:
+ return false
+ }
+}
+
+func isUnconditionalAccept(rule stack.Rule) bool {
+ if !validUnderflow(rule) {
+ return false
+ }
+ _, ok := rule.Target.(stack.AcceptTarget)
+ return ok
+}
+
+func hookFromLinux(hook int) stack.Hook {
+ switch hook {
+ case linux.NF_INET_PRE_ROUTING:
+ return stack.Prerouting
+ case linux.NF_INET_LOCAL_IN:
+ return stack.Input
+ case linux.NF_INET_FORWARD:
+ return stack.Forward
+ case linux.NF_INET_LOCAL_OUT:
+ return stack.Output
+ case linux.NF_INET_POST_ROUTING:
+ return stack.Postrouting
+ }
+ panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
+}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
new file mode 100644
index 000000000..1b4e0ad79
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -0,0 +1,149 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameOwner = "owner"
+
+func init() {
+ registerMatchMaker(ownerMarshaler{})
+}
+
+// ownerMarshaler implements matchMaker for owner matching.
+type ownerMarshaler struct{}
+
+// name implements matchMaker.name.
+func (ownerMarshaler) name() string {
+ return matcherNameOwner
+}
+
+// marshal implements matchMaker.marshal.
+func (ownerMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*OwnerMatcher)
+ iptOwnerInfo := linux.IPTOwnerInfo{
+ UID: matcher.uid,
+ GID: matcher.gid,
+ }
+
+ // Support for UID and GID match.
+ if matcher.matchUID {
+ iptOwnerInfo.Match = linux.XT_OWNER_UID
+ if matcher.invertUID {
+ iptOwnerInfo.Invert = linux.XT_OWNER_UID
+ }
+ }
+ if matcher.matchGID {
+ iptOwnerInfo.Match |= linux.XT_OWNER_GID
+ if matcher.invertGID {
+ iptOwnerInfo.Invert |= linux.XT_OWNER_GID
+ }
+ }
+
+ buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfIPTOwnerInfo {
+ return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.IPTOwnerInfo
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+
+ var owner OwnerMatcher
+ owner.uid = matchData.UID
+ owner.gid = matchData.GID
+
+ // Check flags.
+ if matchData.Match&linux.XT_OWNER_UID != 0 {
+ owner.matchUID = true
+ if matchData.Invert&linux.XT_OWNER_UID != 0 {
+ owner.invertUID = true
+ }
+ }
+ if matchData.Match&linux.XT_OWNER_GID != 0 {
+ owner.matchGID = true
+ if matchData.Invert&linux.XT_OWNER_GID != 0 {
+ owner.invertGID = true
+ }
+ }
+
+ return &owner, nil
+}
+
+type OwnerMatcher struct {
+ uid uint32
+ gid uint32
+ matchUID bool
+ matchGID bool
+ invertUID bool
+ invertGID bool
+}
+
+// Name implements Matcher.Name.
+func (*OwnerMatcher) Name() string {
+ return matcherNameOwner
+}
+
+// Match implements Matcher.Match.
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ // Support only for OUTPUT chain.
+ // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
+ if hook != stack.Output {
+ return false, true
+ }
+
+ // If the packet owner is not set, drop the packet.
+ if pkt.Owner == nil {
+ return false, true
+ }
+
+ var matches bool
+ // Check for UID match.
+ if om.matchUID {
+ if pkt.Owner.UID() == om.uid {
+ matches = true
+ }
+ if matches == om.invertUID {
+ return false, false
+ }
+ }
+
+ // Check for GID match.
+ if om.matchGID {
+ matches = false
+ if pkt.Owner.GID() == om.gid {
+ matches = true
+ }
+ if matches == om.invertGID {
+ return false, false
+ }
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..b91ba3ab3
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,35 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// JumpTarget implements stack.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements stack.Target.Action.
+func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) {
+ return stack.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
new file mode 100644
index 000000000..4f98ee2d5
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -0,0 +1,130 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameTCP = "tcp"
+
+func init() {
+ registerMatchMaker(tcpMarshaler{})
+}
+
+// tcpMarshaler implements matchMaker for TCP matching.
+type tcpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (tcpMarshaler) name() string {
+ return matcherNameTCP
+}
+
+// marshal implements matchMaker.marshal.
+func (tcpMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*TCPMatcher)
+ xttcp := linux.XTTCP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTTCP)
+ return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfXTTCP {
+ return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.XTTCP
+ binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTTCP: %+v", matchData)
+
+ if matchData.Option != 0 ||
+ matchData.FlagMask != 0 ||
+ matchData.FlagCompare != 0 ||
+ matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported TCP matcher flags set")
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber {
+ return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber)
+ }
+
+ return &TCPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// TCPMatcher matches TCP packets and their headers. It implements Matcher.
+type TCPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*TCPMatcher) Name() string {
+ return matcherNameTCP
+}
+
+// Match implements Matcher.Match.
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
+ }
+
+ // We dont't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ tcpHeader := header.TCP(pkt.TransportHeader)
+ if len(tcpHeader) < header.TCPMinimumSize {
+ // There's no valid TCP header here, so we drop the packet immediately.
+ return false, true
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := tcpHeader.SourcePort(); sourcePort < tm.sourcePortStart || tm.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := tcpHeader.DestinationPort(); destinationPort < tm.destinationPortStart || tm.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
new file mode 100644
index 000000000..3f20fc891
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -0,0 +1,129 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameUDP = "udp"
+
+func init() {
+ registerMatchMaker(udpMarshaler{})
+}
+
+// udpMarshaler implements matchMaker for UDP matching.
+type udpMarshaler struct{}
+
+// name implements matchMaker.name.
+func (udpMarshaler) name() string {
+ return matcherNameUDP
+}
+
+// marshal implements matchMaker.marshal.
+func (udpMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*UDPMatcher)
+ xtudp := linux.XTUDP{
+ SourcePortStart: matcher.sourcePortStart,
+ SourcePortEnd: matcher.sourcePortEnd,
+ DestinationPortStart: matcher.destinationPortStart,
+ DestinationPortEnd: matcher.destinationPortEnd,
+ }
+ buf := make([]byte, 0, linux.SizeOfXTUDP)
+ return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfXTUDP {
+ return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may exceed what's
+ // strictly necessary to hold matchData.
+ var matchData linux.XTUDP
+ binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed XTUDP: %+v", matchData)
+
+ if matchData.InverseFlags != 0 {
+ return nil, fmt.Errorf("unsupported UDP matcher inverse flags set")
+ }
+
+ if filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber)
+ }
+
+ return &UDPMatcher{
+ sourcePortStart: matchData.SourcePortStart,
+ sourcePortEnd: matchData.SourcePortEnd,
+ destinationPortStart: matchData.DestinationPortStart,
+ destinationPortEnd: matchData.DestinationPortEnd,
+ }, nil
+}
+
+// UDPMatcher matches UDP packets and their headers. It implements Matcher.
+type UDPMatcher struct {
+ sourcePortStart uint16
+ sourcePortEnd uint16
+ destinationPortStart uint16
+ destinationPortEnd uint16
+}
+
+// Name implements Matcher.Name.
+func (*UDPMatcher) Name() string {
+ return matcherNameUDP
+}
+
+// Match implements Matcher.Match.
+func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
+ netHeader := header.IPv4(pkt.NetworkHeader)
+
+ // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
+ // into the stack.Check codepath as matchers are added.
+ if netHeader.TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
+
+ // We dont't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
+
+ udpHeader := header.UDP(pkt.TransportHeader)
+ if len(udpHeader) < header.UDPMinimumSize {
+ // There's no valid UDP header here, so we drop the packet immediately.
+ return false, true
+ }
+
+ // Check whether the source and destination ports are within the
+ // matching range.
+ if sourcePort := udpHeader.SourcePort(); sourcePort < um.sourcePortStart || um.sourcePortEnd < sourcePort {
+ return false, false
+ }
+ if destinationPort := udpHeader.DestinationPort(); destinationPort < um.destinationPortStart || um.destinationPortEnd < destinationPort {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
new file mode 100644
index 000000000..d5ca3ac56
--- /dev/null
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -0,0 +1,52 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "netlink",
+ srcs = [
+ "message.go",
+ "provider.go",
+ "provider_vfs2.go",
+ "socket.go",
+ "socket_vfs2.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/netlink/port",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
+
+go_test(
+ name = "netlink_test",
+ size = "small",
+ srcs = [
+ "message_test.go",
+ ],
+ deps = [
+ ":netlink",
+ "//pkg/abi/linux",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
new file mode 100644
index 000000000..0899c61d1
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message.go
@@ -0,0 +1,281 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// alignPad returns the length of padding required for alignment.
+//
+// Preconditions: align is a power of two.
+func alignPad(length int, align uint) int {
+ return binary.AlignUp(length, align) - length
+}
+
+// Message contains a complete serialized netlink message.
+type Message struct {
+ hdr linux.NetlinkMessageHeader
+ buf []byte
+}
+
+// NewMessage creates a new Message containing the passed header.
+//
+// The header length will be updated by Finalize.
+func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
+ return &Message{
+ hdr: hdr,
+ buf: binary.Marshal(nil, usermem.ByteOrder, hdr),
+ }
+}
+
+// ParseMessage parses the first message seen at buf, returning the rest of the
+// buffer. If message is malformed, ok of false is returned. For last message,
+// padding check is loose, if there isn't enought padding, whole buf is consumed
+// and ok is set to true.
+func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) {
+ b := BytesView(buf)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+ var hdr linux.NetlinkMessageHeader
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ // Msg portion.
+ totalMsgLen := int(hdr.Length)
+ _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return
+ }
+
+ // Padding.
+ numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message not being aligned, just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return
+ }
+
+ return &Message{
+ hdr: hdr,
+ buf: buf[:totalMsgLen],
+ }, []byte(b), true
+}
+
+// Header returns the header of this message.
+func (m *Message) Header() linux.NetlinkMessageHeader {
+ return m.hdr
+}
+
+// GetData unmarshals the payload message header from this netlink message, and
+// returns the attributes portion.
+func (m *Message) GetData(msg interface{}) (AttrsView, bool) {
+ b := BytesView(m.buf)
+
+ _, ok := b.Extract(linux.NetlinkMessageHeaderSize)
+ if !ok {
+ return nil, false
+ }
+
+ size := int(binary.Size(msg))
+ msgBytes, ok := b.Extract(size)
+ if !ok {
+ return nil, false
+ }
+ binary.Unmarshal(msgBytes, usermem.ByteOrder, msg)
+
+ numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO)
+ // Linux permits the last message not being aligned, just consume all of it.
+ // Ref: net/netlink/af_netlink.c:netlink_rcv_skb
+ if numPad > len(b) {
+ numPad = len(b)
+ }
+ _, ok = b.Extract(numPad)
+ if !ok {
+ return nil, false
+ }
+
+ return AttrsView(b), true
+}
+
+// Finalize returns the []byte containing the entire message, with the total
+// length set in the message header. The Message must not be modified after
+// calling Finalize.
+func (m *Message) Finalize() []byte {
+ // Update length, which is the first 4 bytes of the header.
+ usermem.ByteOrder.PutUint32(m.buf, uint32(len(m.buf)))
+
+ // Align the message. Note that the message length in the header (set
+ // above) is the useful length of the message, not the total aligned
+ // length. See net/netlink/af_netlink.c:__nlmsg_put.
+ aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ m.putZeros(aligned - len(m.buf))
+ return m.buf
+}
+
+// putZeros adds n zeros to the message.
+func (m *Message) putZeros(n int) {
+ for n > 0 {
+ m.buf = append(m.buf, 0)
+ n--
+ }
+}
+
+// Put serializes v into the message.
+func (m *Message) Put(v interface{}) {
+ m.buf = binary.Marshal(m.buf, usermem.ByteOrder, v)
+}
+
+// PutAttr adds v to the message as a netlink attribute.
+//
+// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
+// binary.Size(v) fits in math.MaxUint16 bytes.
+func (m *Message) PutAttr(atype uint16, v interface{}) {
+ l := linux.NetlinkAttrHeaderSize + int(binary.Size(v))
+ if l > math.MaxUint16 {
+ panic(fmt.Sprintf("attribute too large: %d", l))
+ }
+
+ m.Put(linux.NetlinkAttrHeader{
+ Type: atype,
+ Length: uint16(l),
+ })
+ m.Put(v)
+
+ // Align the attribute.
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ m.putZeros(aligned - l)
+}
+
+// PutAttrString adds s to the message as a netlink attribute.
+func (m *Message) PutAttrString(atype uint16, s string) {
+ l := linux.NetlinkAttrHeaderSize + len(s) + 1
+ m.Put(linux.NetlinkAttrHeader{
+ Type: atype,
+ Length: uint16(l),
+ })
+
+ // String + NUL-termination.
+ m.Put([]byte(s))
+ m.putZeros(1)
+
+ // Align the attribute.
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
+ m.putZeros(aligned - l)
+}
+
+// MessageSet contains a series of netlink messages.
+type MessageSet struct {
+ // Multi indicates that this a multi-part message, to be terminated by
+ // NLMSG_DONE. NLMSG_DONE is sent even if the set contains only one
+ // Message.
+ //
+ // If Multi is set, all added messages will have NLM_F_MULTI set.
+ Multi bool
+
+ // PortID is the destination port for all messages.
+ PortID int32
+
+ // Seq is the sequence counter for all messages in the set.
+ Seq uint32
+
+ // Messages contains the messages in the set.
+ Messages []*Message
+}
+
+// NewMessageSet creates a new MessageSet.
+//
+// portID is the destination port to set as PortID in all messages.
+//
+// seq is the sequence counter to set as seq in all messages in the set.
+func NewMessageSet(portID int32, seq uint32) *MessageSet {
+ return &MessageSet{
+ PortID: portID,
+ Seq: seq,
+ }
+}
+
+// AddMessage adds a new message to the set and returns it for further
+// additions.
+//
+// The passed header will have Seq, PortID and the multi flag set
+// automatically.
+func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
+ hdr.Seq = ms.Seq
+ hdr.PortID = uint32(ms.PortID)
+ if ms.Multi {
+ hdr.Flags |= linux.NLM_F_MULTI
+ }
+
+ m := NewMessage(hdr)
+ ms.Messages = append(ms.Messages, m)
+ return m
+}
+
+// AttrsView is a view into the attributes portion of a netlink message.
+type AttrsView []byte
+
+// Empty returns whether there is no attribute left in v.
+func (v AttrsView) Empty() bool {
+ return len(v) == 0
+}
+
+// ParseFirst parses first netlink attribute at the beginning of v.
+func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) {
+ b := BytesView(v)
+
+ hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+ binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr)
+
+ value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize)
+ if !ok {
+ return
+ }
+
+ _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO))
+ if !ok {
+ return
+ }
+
+ return hdr, value, AttrsView(b), ok
+}
+
+// BytesView supports extracting data from a byte slice with bounds checking.
+type BytesView []byte
+
+// Extract removes the first n bytes from v and returns it. If n is out of
+// bounds, it returns false.
+func (v *BytesView) Extract(n int) ([]byte, bool) {
+ if n < 0 || n > len(*v) {
+ return nil, false
+ }
+ extracted := (*v)[:n]
+ *v = (*v)[n:]
+ return extracted, true
+}
diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go
new file mode 100644
index 000000000..ef13d9386
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message_test.go
@@ -0,0 +1,312 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package message_test
+
+import (
+ "bytes"
+ "reflect"
+ "testing"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+)
+
+type dummyNetlinkMsg struct {
+ Foo uint16
+}
+
+func TestParseMessage(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ header linux.NetlinkMessageHeader
+ dataMsg *dummyNetlinkMsg
+ restLen int
+ ok bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid with next message",
+ input: []byte{
+ 0x14, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ 0xFF, // Next message (rest)
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 20,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 1,
+ ok: true,
+ },
+ {
+ desc: "valid for last message without padding",
+ input: []byte{
+ 0x12, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 18,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "valid for last message not to be aligned",
+ input: []byte{
+ 0x13, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, // Data message
+ 0x00, // Excessive 1 byte permitted at end
+ },
+ header: linux.NetlinkMessageHeader{
+ Length: 19,
+ Type: 1,
+ Flags: 2,
+ Seq: 3,
+ PortID: 4,
+ },
+ dataMsg: &dummyNetlinkMsg{
+ Foo: 0x3130,
+ },
+ restLen: 0,
+ ok: true,
+ },
+ {
+ desc: "header.Length too short",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header.Length too long",
+ input: []byte{
+ 0xFF, 0xFF, 0x00, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x02, 0x00, // Flags
+ 0x03, 0x00, 0x00, 0x00, // Seq
+ 0x04, 0x00, 0x00, 0x00, // PortID
+ 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding
+ },
+ ok: false,
+ },
+ {
+ desc: "header incomplete",
+ input: []byte{
+ 0x04, 0x00, 0x00, 0x00, // Length
+ },
+ ok: false,
+ },
+ {
+ desc: "empty message",
+ input: []byte{},
+ ok: false,
+ },
+ }
+ for _, test := range tests {
+ msg, rest, ok := netlink.ParseMessage(test.input)
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ continue
+ }
+ if !test.ok {
+ continue
+ }
+ if !reflect.DeepEqual(msg.Header(), test.header) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header)
+ }
+
+ dataMsg := &dummyNetlinkMsg{}
+ _, dataOk := msg.GetData(dataMsg)
+ if !dataOk {
+ t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk)
+ } else if !reflect.DeepEqual(dataMsg, test.dataMsg) {
+ t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg)
+ }
+
+ if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
+
+func TestAttrView(t *testing.T) {
+ tests := []struct {
+ desc string
+ input []byte
+
+ // Outputs for ParseFirst.
+ hdr linux.NetlinkAttrHeader
+ value []byte
+ restLen int
+ ok bool
+
+ // Outputs for Empty.
+ isEmpty bool
+ }{
+ {
+ desc: "valid",
+ input: []byte{
+ 0x06, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 6,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 0,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "at alignment with rest data",
+ input: []byte{
+ 0x08, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ 0xFF, 0xFE, // Rest data
+ },
+ hdr: linux.NetlinkAttrHeader{
+ Length: 8,
+ Type: 1,
+ },
+ value: []byte{0x30, 0x31, 0x32, 0x33},
+ restLen: 2,
+ ok: true,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too long",
+ input: []byte{
+ 0xFF, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "hdr.Length too short",
+ input: []byte{
+ 0x01, 0x00, // Length
+ 0x01, 0x00, // Type
+ 0x30, 0x31, 0x32, 0x33, // Data
+ },
+ ok: false,
+ isEmpty: false,
+ },
+ {
+ desc: "empty",
+ input: []byte{},
+ ok: false,
+ isEmpty: true,
+ },
+ }
+ for _, test := range tests {
+ attrs := netlink.AttrsView(test.input)
+
+ // Test ParseFirst().
+ hdr, value, rest, ok := attrs.ParseFirst()
+ if ok != test.ok {
+ t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok)
+ } else if test.ok {
+ if !reflect.DeepEqual(hdr, test.hdr) {
+ t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr)
+ }
+ if !bytes.Equal(value, test.value) {
+ t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value)
+ }
+ if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) {
+ t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest)
+ }
+ }
+
+ // Test Empty().
+ if got, want := attrs.Empty(), test.isEmpty; got != want {
+ t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want)
+ }
+ }
+}
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
new file mode 100644
index 000000000..3a22923d8
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "port",
+ srcs = ["port.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/sync"],
+)
+
+go_test(
+ name = "port_test",
+ srcs = ["port_test.go"],
+ library = ":port",
+)
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
new file mode 100644
index 000000000..2cd3afc22
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -0,0 +1,117 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package port provides port ID allocation for netlink sockets.
+//
+// A netlink port is any int32 value. Positive ports are typically equivalent
+// to the PID of the binding process. If that port is unavailable, negative
+// ports are searched to find a free port that will not conflict with other
+// PIDS.
+package port
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+
+ "gvisor.dev/gvisor/pkg/sync"
+)
+
+// maxPorts is a sanity limit on the maximum number of ports to allocate per
+// protocol.
+const maxPorts = 10000
+
+// Manager allocates netlink port IDs.
+//
+// +stateify savable
+type Manager struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // ports contains a map of allocated ports for each protocol.
+ ports map[int]map[int32]struct{}
+}
+
+// New creates a new Manager.
+func New() *Manager {
+ return &Manager{
+ ports: make(map[int]map[int32]struct{}),
+ }
+}
+
+// Allocate reserves a new port ID for protocol. hint will be taken if
+// available.
+func (m *Manager) Allocate(protocol int, hint int32) (int32, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ proto, ok := m.ports[protocol]
+ if !ok {
+ proto = make(map[int32]struct{})
+ // Port 0 is reserved for the kernel.
+ proto[0] = struct{}{}
+ m.ports[protocol] = proto
+ }
+
+ if len(proto) >= maxPorts {
+ return 0, false
+ }
+
+ if _, ok := proto[hint]; !ok {
+ // Hint is available, reserve it.
+ proto[hint] = struct{}{}
+ return hint, true
+ }
+
+ // Search for any free port in [math.MinInt32, -4096). The positive
+ // port space is left open for pid-based allocations. This behavior is
+ // consistent with Linux.
+ start := int32(math.MinInt32 + rand.Int63n(math.MaxInt32-4096+1))
+ curr := start
+ for {
+ if _, ok := proto[curr]; !ok {
+ proto[curr] = struct{}{}
+ return curr, true
+ }
+
+ curr--
+ if curr >= -4096 {
+ curr = -4097
+ }
+ if curr == start {
+ // Nothing found. We should always find a free port
+ // because maxPorts < -4096 - MinInt32.
+ panic(fmt.Sprintf("No free port found in %+v", proto))
+ }
+ }
+}
+
+// Release frees the specified port for protocol.
+//
+// Preconditions: port is already allocated.
+func (m *Manager) Release(protocol int, port int32) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ proto, ok := m.ports[protocol]
+ if !ok {
+ panic(fmt.Sprintf("Released port %d for protocol %d which has no allocations", port, protocol))
+ }
+
+ if _, ok := proto[port]; !ok {
+ panic(fmt.Sprintf("Released port %d for protocol %d is not allocated", port, protocol))
+ }
+
+ delete(proto, port)
+}
diff --git a/pkg/sentry/socket/netlink/port/port_test.go b/pkg/sentry/socket/netlink/port/port_test.go
new file mode 100644
index 000000000..516f6cd6c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port_test.go
@@ -0,0 +1,82 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package port
+
+import (
+ "testing"
+)
+
+func TestAllocateHint(t *testing.T) {
+ m := New()
+
+ // We can get the hint port.
+ p, ok := m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(0, 1) got %d want 1", p)
+ }
+
+ // Hint is taken.
+ p, ok = m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p == 1 {
+ t.Errorf("m.Allocate(0, 1) got 1 want anything else")
+ }
+
+ // Hint is available for a different protocol.
+ p, ok = m.Allocate(1, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(1, 1) got %d want 1", p)
+ }
+
+ m.Release(0, 1)
+
+ // Hint is available again after release.
+ p, ok = m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(0, 1) got %d want 1", p)
+ }
+}
+
+func TestAllocateExhausted(t *testing.T) {
+ m := New()
+
+ // Fill all ports (0 is already reserved).
+ for i := int32(1); i < maxPorts; i++ {
+ p, ok := m.Allocate(0, i)
+ if !ok {
+ t.Fatalf("m.Allocate got !ok want ok")
+ }
+ if p != i {
+ t.Fatalf("m.Allocate(0, %d) got %d want %d", i, p, i)
+ }
+ }
+
+ // Now no more can be allocated.
+ p, ok := m.Allocate(0, 1)
+ if ok {
+ t.Errorf("m.Allocate got %d, ok want !ok", p)
+ }
+}
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
new file mode 100644
index 000000000..0d45e5053
--- /dev/null
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -0,0 +1,116 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// Protocol is the implementation of a netlink socket protocol.
+type Protocol interface {
+ // Protocol returns the Linux netlink protocol value.
+ Protocol() int
+
+ // CanSend returns true if this protocol may ever send messages.
+ //
+ // TODO(gvisor.dev/issue/1119): This is a workaround to allow
+ // advertising support for otherwise unimplemented features on sockets
+ // that will never send messages, thus making those features no-ops.
+ CanSend() bool
+
+ // ProcessMessage processes a single message from userspace.
+ //
+ // If err == nil, any messages added to ms will be sent back to the
+ // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE
+ // message to be sent even if ms contains no messages.
+ ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error
+}
+
+// Provider is a function that creates a new Protocol for a specific netlink
+// protocol.
+//
+// Note that this is distinct from socket.Provider, which is used for all
+// socket families.
+type Provider func(t *kernel.Task) (Protocol, *syserr.Error)
+
+// protocols holds a map of all known address protocols and their provider.
+var protocols = make(map[int]Provider)
+
+// RegisterProvider registers the provider of a given address protocol so that
+// netlink sockets of that type can be created via socket(2).
+//
+// Preconditions: May only be called before any netlink sockets are created.
+func RegisterProvider(protocol int, provider Provider) {
+ if p, ok := protocols[protocol]; ok {
+ panic(fmt.Sprintf("Netlink protocol %d already provided by %+v", protocol, p))
+ }
+
+ protocols[protocol] = provider
+}
+
+// LINT.IfChange
+
+// socketProvider implements socket.Provider.
+type socketProvider struct {
+}
+
+// Socket implements socket.Provider.Socket.
+func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Netlink sockets must be specified as datagram or raw, but they
+ // behave the same regardless of type.
+ if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW {
+ return nil, syserr.ErrSocketNotSupported
+ }
+
+ provider, ok := protocols[protocol]
+ if !ok {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ p, err := provider(t)
+ if err != nil {
+ return nil, err
+ }
+
+ s, err := NewSocket(t, stype, p)
+ if err != nil {
+ return nil, err
+ }
+
+ d := socket.NewDirent(t, netlinkSocketDevice)
+ defer d.DecRef()
+ return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil
+}
+
+// Pair implements socket.Provider.Pair by returning an error.
+func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ // Netlink sockets never supports creating socket pairs.
+ return nil, nil, syserr.ErrNotSupported
+}
+
+// LINT.ThenChange(./provider_vfs2.go)
+
+// init registers the socket provider.
+func init() {
+ socket.RegisterProvider(linux.AF_NETLINK, &socketProvider{})
+ socket.RegisterProviderVFS2(linux.AF_NETLINK, &socketProviderVFS2{})
+}
diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go
new file mode 100644
index 000000000..bb205be0d
--- /dev/null
+++ b/pkg/sentry/socket/netlink/provider_vfs2.go
@@ -0,0 +1,69 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// socketProviderVFS2 implements socket.Provider.
+type socketProviderVFS2 struct {
+}
+
+// Socket implements socket.Provider.Socket.
+func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Netlink sockets must be specified as datagram or raw, but they
+ // behave the same regardless of type.
+ if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW {
+ return nil, syserr.ErrSocketNotSupported
+ }
+
+ provider, ok := protocols[protocol]
+ if !ok {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ p, err := provider(t)
+ if err != nil {
+ return nil, err
+ }
+
+ s, err := NewVFS2(t, stype, p)
+ if err != nil {
+ return nil, err
+ }
+
+ vfsfd := &s.vfsfd
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+ if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Pair implements socket.Provider.Pair by returning an error.
+func (*socketProviderVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Netlink sockets never supports creating socket pairs.
+ return nil, nil, syserr.ErrNotSupported
+}
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
new file mode 100644
index 000000000..93127398d
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -0,0 +1,20 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "route",
+ srcs = [
+ "protocol.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
new file mode 100644
index 000000000..c84d8bd7c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -0,0 +1,498 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package route provides a NETLINK_ROUTE socket protocol.
+package route
+
+import (
+ "bytes"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// commandKind describes the operational class of a message type.
+//
+// The route message types use the lower 2 bits of the type to describe class
+// of command.
+type commandKind int
+
+const (
+ kindNew commandKind = 0x0
+ kindDel = 0x1
+ kindGet = 0x2
+ kindSet = 0x3
+)
+
+func typeKind(typ uint16) commandKind {
+ return commandKind(typ & 0x3)
+}
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_ROUTE netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_ROUTE
+}
+
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return true
+}
+
+// dumpLinks handles RTM_GETLINK dump requests.
+func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
+ // ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some
+ // userspace applications (including glibc) still include rtgenmsg.
+ // Linux has a workaround based on the total message length.
+ //
+ // We don't bother to check for either, since we don't support any
+ // extra attributes that may be included anyways.
+ //
+ // The message may also contain netlink attribute IFLA_EXT_MASK, which
+ // we don't support.
+
+ // The RTM_GETLINK dump response is a set of messages each containing
+ // an InterfaceInfoMessage followed by a set of netlink attributes.
+
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ for idx, i := range stack.Interfaces() {
+ addNewLinkMessage(ms, idx, i)
+ }
+
+ return nil
+}
+
+// getLinks handles RTM_GETLINK requests.
+func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ // Parse message.
+ var ifi linux.InterfaceInfoMessage
+ attrs, ok := msg.GetData(&ifi)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ // Parse attributes.
+ var byName []byte
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ switch ahdr.Type {
+ case linux.IFLA_IFNAME:
+ if len(value) < 1 {
+ return syserr.ErrInvalidArgument
+ }
+ byName = value[:len(value)-1]
+
+ // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK.
+ }
+ }
+
+ found := false
+ for idx, i := range stack.Interfaces() {
+ switch {
+ case ifi.Index > 0:
+ if idx != ifi.Index {
+ continue
+ }
+ case byName != nil:
+ if string(byName) != i.Name {
+ continue
+ }
+ default:
+ // Criteria not specified.
+ return syserr.ErrInvalidArgument
+ }
+
+ addNewLinkMessage(ms, idx, i)
+ found = true
+ break
+ }
+ if !found {
+ return syserr.ErrNoDevice
+ }
+ return nil
+}
+
+// addNewLinkMessage appends RTM_NEWLINK message for the given interface into
+// the message set.
+func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWLINK,
+ })
+
+ m.Put(linux.InterfaceInfoMessage{
+ Family: linux.AF_UNSPEC,
+ Type: i.DeviceType,
+ Index: idx,
+ Flags: i.Flags,
+ })
+
+ m.PutAttrString(linux.IFLA_IFNAME, i.Name)
+ m.PutAttr(linux.IFLA_MTU, i.MTU)
+
+ mac := make([]byte, 6)
+ brd := mac
+ if len(i.Addr) > 0 {
+ mac = i.Addr
+ brd = bytes.Repeat([]byte{0xff}, len(i.Addr))
+ }
+ m.PutAttr(linux.IFLA_ADDRESS, mac)
+ m.PutAttr(linux.IFLA_BROADCAST, brd)
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+}
+
+// dumpAddrs handles RTM_GETADDR dump requests.
+func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // RTM_GETADDR dump requests need not contain anything more than the
+ // netlink header and 1 byte protocol family common to all
+ // NETLINK_ROUTE requests.
+ //
+ // TODO(b/68878065): Filter output by passed protocol family.
+
+ // The RTM_GETADDR dump response is a set of RTM_NEWADDR messages each
+ // containing an InterfaceAddrMessage followed by a set of netlink
+ // attributes.
+
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network devices.
+ return nil
+ }
+
+ for id, as := range stack.InterfaceAddrs() {
+ for _, a := range as {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWADDR,
+ })
+
+ m.Put(linux.InterfaceAddrMessage{
+ Family: a.Family,
+ PrefixLen: a.PrefixLen,
+ Index: uint32(id),
+ })
+
+ m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr))
+ m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr))
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+ }
+ }
+
+ return nil
+}
+
+// commonPrefixLen reports the length of the longest IP address prefix.
+// This is a simplied version from Golang's src/net/addrselect.go.
+func commonPrefixLen(a, b []byte) (cpl int) {
+ for len(a) > 0 {
+ if a[0] == b[0] {
+ cpl += 8
+ a = a[1:]
+ b = b[1:]
+ continue
+ }
+ bits := 8
+ ab, bb := a[0], b[0]
+ for {
+ ab >>= 1
+ bb >>= 1
+ bits--
+ if ab == bb {
+ cpl += bits
+ return
+ }
+ }
+ }
+ return
+}
+
+// fillRoute returns the Route using LPM algorithm. Refer to Linux's
+// net/ipv4/route.c:rt_fill_info().
+func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) {
+ family := uint8(linux.AF_INET)
+ if len(addr) != 4 {
+ family = linux.AF_INET6
+ }
+
+ idx := -1 // Index of the Route rule to be returned.
+ idxDef := -1 // Index of the default route rule.
+ prefix := 0 // Current longest prefix.
+ for i, route := range routes {
+ if route.Family != family {
+ continue
+ }
+
+ if len(route.GatewayAddr) > 0 && route.DstLen == 0 {
+ idxDef = i
+ continue
+ }
+
+ cpl := commonPrefixLen(addr, route.DstAddr)
+ if cpl < int(route.DstLen) {
+ continue
+ }
+ cpl = int(route.DstLen)
+ if cpl > prefix {
+ idx = i
+ prefix = cpl
+ }
+ }
+ if idx == -1 {
+ idx = idxDef
+ }
+ if idx == -1 {
+ return inet.Route{}, syserr.ErrNoRoute
+ }
+
+ route := routes[idx]
+ if family == linux.AF_INET {
+ route.DstLen = 32
+ } else {
+ route.DstLen = 128
+ }
+ route.DstAddr = addr
+ route.Flags |= linux.RTM_F_CLONED // This route is cloned.
+ return route, nil
+}
+
+// parseForDestination parses a message as format of RouteMessage-RtAttr-dst.
+func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) {
+ var rtMsg linux.RouteMessage
+ attrs, ok := msg.GetData(&rtMsg)
+ if !ok {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See
+ // commit bc234301af12. Note we don't check this flag for backward
+ // compatibility.
+ if rtMsg.Flags != 0 && rtMsg.Flags != linux.RTM_F_LOOKUP_TABLE {
+ return nil, syserr.ErrNotSupported
+ }
+
+ // Expect first attribute is RTA_DST.
+ if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST {
+ return value, nil
+ }
+ return nil, syserr.ErrInvalidArgument
+}
+
+// dumpRoutes handles RTM_GETROUTE requests.
+func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // RTM_GETROUTE dump requests need not contain anything more than the
+ // netlink header and 1 byte protocol family common to all
+ // NETLINK_ROUTE requests.
+
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network routes.
+ return nil
+ }
+
+ hdr := msg.Header()
+ routeTables := stack.RouteTable()
+
+ if hdr.Flags == linux.NLM_F_REQUEST {
+ dst, err := parseForDestination(msg)
+ if err != nil {
+ return err
+ }
+ route, err := fillRoute(routeTables, dst)
+ if err != nil {
+ // TODO(gvisor.dev/issue/1237): return NLMSG_ERROR with ENETUNREACH.
+ return syserr.ErrNotSupported
+ }
+ routeTables = append([]inet.Route{}, route)
+ } else if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP {
+ // We always send back an NLMSG_DONE.
+ ms.Multi = true
+ } else {
+ // TODO(b/68878065): Only above cases are supported.
+ return syserr.ErrNotSupported
+ }
+
+ for _, rt := range routeTables {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.RTM_NEWROUTE,
+ })
+
+ m.Put(linux.RouteMessage{
+ Family: rt.Family,
+ DstLen: rt.DstLen,
+ SrcLen: rt.SrcLen,
+ TOS: rt.TOS,
+
+ // Always return the main table since we don't have multiple
+ // routing tables.
+ Table: linux.RT_TABLE_MAIN,
+ Protocol: rt.Protocol,
+ Scope: rt.Scope,
+ Type: rt.Type,
+
+ Flags: rt.Flags,
+ })
+
+ m.PutAttr(254, []byte{123})
+ if rt.DstLen > 0 {
+ m.PutAttr(linux.RTA_DST, rt.DstAddr)
+ }
+ if rt.SrcLen > 0 {
+ m.PutAttr(linux.RTA_SRC, rt.SrcAddr)
+ }
+ if rt.OutputInterface != 0 {
+ m.PutAttr(linux.RTA_OIF, rt.OutputInterface)
+ }
+ if len(rt.GatewayAddr) > 0 {
+ m.PutAttr(linux.RTA_GATEWAY, rt.GatewayAddr)
+ }
+
+ // TODO(gvisor.dev/issue/578): There are many more attributes.
+ }
+
+ return nil
+}
+
+// newAddr handles RTM_NEWADDR requests.
+func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifa linux.InterfaceAddrMessage
+ attrs, ok := msg.GetData(&ifa)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+
+ for !attrs.Empty() {
+ ahdr, value, rest, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ attrs = rest
+
+ switch ahdr.Type {
+ case linux.IFA_LOCAL:
+ err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{
+ Family: ifa.Family,
+ PrefixLen: ifa.PrefixLen,
+ Flags: ifa.Flags,
+ Addr: value,
+ })
+ if err == syscall.EEXIST {
+ flags := msg.Header().Flags
+ if flags&linux.NLM_F_EXCL != 0 {
+ return syserr.ErrExists
+ }
+ } else if err != nil {
+ return syserr.ErrInvalidArgument
+ }
+ }
+ }
+ return nil
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ hdr := msg.Header()
+
+ // All messages start with a 1 byte protocol family.
+ var family uint8
+ if _, ok := msg.GetData(&family); !ok {
+ // Linux ignores messages missing the protocol family. See
+ // net/core/rtnetlink.c:rtnetlink_rcv_msg.
+ return nil
+ }
+
+ // Non-GET message types require CAP_NET_ADMIN.
+ if typeKind(hdr.Type) != kindGet {
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_ADMIN) {
+ return syserr.ErrPermissionDenied
+ }
+ }
+
+ if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP {
+ // TODO(b/68878065): Only the dump variant of the types below are
+ // supported.
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.dumpLinks(ctx, msg, ms)
+ case linux.RTM_GETADDR:
+ return p.dumpAddrs(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
+ } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST {
+ switch hdr.Type {
+ case linux.RTM_GETLINK:
+ return p.getLink(ctx, msg, ms)
+ case linux.RTM_GETROUTE:
+ return p.dumpRoutes(ctx, msg, ms)
+ case linux.RTM_NEWADDR:
+ return p.newAddr(ctx, msg, ms)
+ default:
+ return syserr.ErrNotSupported
+ }
+ }
+ return syserr.ErrNotSupported
+}
+
+// init registers the NETLINK_ROUTE provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_ROUTE, NewProtocol)
+}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
new file mode 100644
index 000000000..81f34c5a2
--- /dev/null
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -0,0 +1,780 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netlink provides core functionality for netlink sockets.
+package netlink
+
+import (
+ "math"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+const sizeOfInt32 int = 4
+
+const (
+ // minBufferSize is the smallest size of a send buffer.
+ minSendBufferSize = 4 << 10 // 4096 bytes.
+
+ // defaultSendBufferSize is the default size for the send buffer.
+ defaultSendBufferSize = 16 * 1024
+
+ // maxBufferSize is the largest size a send buffer can grow to.
+ maxSendBufferSize = 4 << 20 // 4MB
+)
+
+var errNoFilter = syserr.New("no filter attached", linux.ENOENT)
+
+// netlinkSocketDevice is the netlink socket virtual device.
+var netlinkSocketDevice = device.NewAnonDevice()
+
+// LINT.IfChange
+
+// Socket is the base socket type for netlink sockets.
+//
+// This implementation only supports userspace sending and receiving messages
+// to/from the kernel.
+//
+// Socket implements socket.Socket and transport.Credentialer.
+//
+// +stateify savable
+type Socket struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socket.SendReceiveTimeout
+
+ // ports provides netlink port allocation.
+ ports *port.Manager
+
+ // protocol is the netlink protocol implementation.
+ protocol Protocol
+
+ // skType is the socket type. This is either SOCK_DGRAM or SOCK_RAW for
+ // netlink sockets.
+ skType linux.SockType
+
+ // ep is a datagram unix endpoint used to buffer messages sent from the
+ // kernel to userspace. RecvMsg reads messages from this endpoint.
+ ep transport.Endpoint
+
+ // connection is the kernel's connection to ep, used to write messages
+ // sent to userspace.
+ connection transport.ConnectedEndpoint
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // bound indicates that portid is valid.
+ bound bool
+
+ // portID is the port ID allocated for this socket.
+ portID int32
+
+ // sendBufferSize is the send buffer "size". We don't actually have a
+ // fixed buffer but only consume this many bytes.
+ sendBufferSize uint32
+
+ // passcred indicates if this socket wants SCM credentials.
+ passcred bool
+
+ // filter indicates that this socket has a BPF filter "installed".
+ //
+ // TODO(gvisor.dev/issue/1119): We don't actually support filtering,
+ // this is just bookkeeping for tracking add/remove.
+ filter bool
+}
+
+var _ socket.Socket = (*Socket)(nil)
+var _ transport.Credentialer = (*Socket)(nil)
+
+// NewSocket creates a new Socket.
+func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) {
+ // Datagram endpoint used to buffer kernel -> user messages.
+ ep := transport.NewConnectionless(t)
+
+ // Bind the endpoint for good measure so we can connect to it. The
+ // bound address will never be exposed.
+ if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ // Create a connection from which the kernel can write messages.
+ connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ return &Socket{
+ socketOpsCommon: socketOpsCommon{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ skType: skType,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ },
+ }, nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ s.connection.Release()
+ s.ep.Close()
+
+ if s.bound {
+ s.ports.Release(s.protocol.Protocol(), s.portID)
+ }
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ // ep holds messages to be read and thus handles EventIn readiness.
+ ready := s.ep.Readiness(mask)
+
+ if mask&waiter.EventOut == waiter.EventOut {
+ // sendMsg handles messages synchronously and is thus always
+ // ready for writing.
+ ready |= waiter.EventOut
+ }
+
+ return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+ // Writable readiness never changes, so no registration is needed.
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// Passcred implements transport.Credentialer.Passcred.
+func (s *socketOpsCommon) Passcred() bool {
+ s.mu.Lock()
+ passcred := s.passcred
+ s.mu.Unlock()
+ return passcred
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *socketOpsCommon) ConnectedPasscred() bool {
+ // This socket is connected to the kernel, which doesn't need creds.
+ //
+ // This is arbitrary, as ConnectedPasscred on this type has no callers.
+ return false
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) {
+ // TODO(b/68878065): no ioctls supported.
+ return 0, syserror.ENOTTY
+}
+
+// ExtractSockAddr extracts the SockAddrNetlink from b.
+func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
+ if len(b) < linux.SockAddrNetlinkSize {
+ return nil, syserr.ErrBadAddress
+ }
+
+ var sa linux.SockAddrNetlink
+ binary.Unmarshal(b[:linux.SockAddrNetlinkSize], usermem.ByteOrder, &sa)
+
+ if sa.Family != linux.AF_NETLINK {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return &sa, nil
+}
+
+// bindPort binds this socket to a port, preferring 'port' if it is available.
+//
+// port of 0 defaults to the ThreadGroup ID.
+//
+// Preconditions: mu is held.
+func (s *socketOpsCommon) bindPort(t *kernel.Task, port int32) *syserr.Error {
+ if s.bound {
+ // Re-binding is only allowed if the port doesn't change.
+ if port != s.portID {
+ return syserr.ErrInvalidArgument
+ }
+
+ return nil
+ }
+
+ if port == 0 {
+ port = int32(t.ThreadGroup().ID())
+ }
+ port, ok := s.ports.Allocate(s.protocol.Protocol(), port)
+ if !ok {
+ return syserr.ErrBusy
+ }
+
+ s.portID = port
+ s.bound = true
+ return nil
+}
+
+// Bind implements socket.Socket.Bind.
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ a, err := ExtractSockAddr(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ return s.bindPort(t, int32(a.PortID))
+}
+
+// Connect implements socket.Socket.Connect.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ a, err := ExtractSockAddr(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ if a.PortID == 0 {
+ // Netlink sockets default to connected to the kernel, but
+ // connecting anyways automatically binds if not already bound.
+ if !s.bound {
+ // Pass port 0 to get an auto-selected port ID.
+ return s.bindPort(t, 0)
+ }
+ return nil
+ }
+
+ // We don't support non-kernel destination ports. Linux returns EPERM
+ // if applications attempt to do this without NL_CFG_F_NONROOT_SEND, so
+ // we emulate that.
+ return syserr.ErrPermissionDenied
+}
+
+// Accept implements socket.Socket.Accept.
+func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Netlink sockets never support accept.
+ return 0, nil, 0, syserr.ErrNotSupported
+}
+
+// Listen implements socket.Socket.Listen.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ // Netlink sockets never support listen.
+ return syserr.ErrNotSupported
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ // Netlink sockets never support shutdown.
+ return syserr.ErrNotSupported
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ switch level {
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ return int32(s.sendBufferSize), nil
+
+ case linux.SO_RCVBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // We don't have limit on receiving size.
+ return int32(math.MaxInt32), nil
+
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ var passcred int32
+ if s.Passcred() {
+ passcred = 1
+ }
+ return passcred, nil
+
+ default:
+ socket.GetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ case linux.SOL_NETLINK:
+ switch name {
+ case linux.NETLINK_BROADCAST_ERROR,
+ linux.NETLINK_CAP_ACK,
+ linux.NETLINK_DUMP_STRICT_CHK,
+ linux.NETLINK_EXT_ACK,
+ linux.NETLINK_LIST_MEMBERSHIPS,
+ linux.NETLINK_NO_ENOBUFS,
+ linux.NETLINK_PKTINFO:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+ }
+ // TODO(b/68878065): other sockopts are not supported.
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ switch level {
+ case linux.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ size := usermem.ByteOrder.Uint32(opt)
+ if size < minSendBufferSize {
+ size = minSendBufferSize
+ } else if size > maxSendBufferSize {
+ size = maxSendBufferSize
+ }
+ s.mu.Lock()
+ s.sendBufferSize = size
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_RCVBUF:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ // We don't have limit on receiving size. So just accept anything as
+ // valid for compatibility.
+ return nil
+
+ case linux.SO_PASSCRED:
+ if len(opt) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ passcred := usermem.ByteOrder.Uint32(opt)
+
+ s.mu.Lock()
+ s.passcred = passcred != 0
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_ATTACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): We don't actually
+ // support filtering. If this socket can't ever send
+ // messages, then there is nothing to filter and we can
+ // advertise support. Otherwise, be conservative and
+ // return an error.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ s.filter = true
+ s.mu.Unlock()
+ return nil
+
+ case linux.SO_DETACH_FILTER:
+ // TODO(gvisor.dev/issue/1119): See above.
+ if s.protocol.CanSend() {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ s.mu.Lock()
+ filter := s.filter
+ s.filter = false
+ s.mu.Unlock()
+
+ if !filter {
+ return errNoFilter
+ }
+
+ return nil
+
+ default:
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ case linux.SOL_NETLINK:
+ switch name {
+ case linux.NETLINK_ADD_MEMBERSHIP,
+ linux.NETLINK_BROADCAST_ERROR,
+ linux.NETLINK_CAP_ACK,
+ linux.NETLINK_DROP_MEMBERSHIP,
+ linux.NETLINK_DUMP_STRICT_CHK,
+ linux.NETLINK_EXT_ACK,
+ linux.NETLINK_LISTEN_ALL_NSID,
+ linux.NETLINK_NO_ENOBUFS,
+ linux.NETLINK_PKTINFO:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ }
+ // TODO(b/68878065): other sockopts are not supported.
+ return syserr.ErrProtocolNotAvailable
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ sa := &linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ PortID: uint32(s.portID),
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ sa := &linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ // TODO(b/68878065): Support non-kernel peers. For now the peer
+ // must be the kernel.
+ PortID: 0,
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ from := &linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ PortID: 0,
+ }
+ fromLen := uint32(binary.Size(from))
+
+ trunc := flags&linux.MSG_TRUNC != 0
+
+ r := unix.EndpointReader{
+ Ctx: t,
+ Endpoint: s.ep,
+ Peek: flags&linux.MSG_PEEK != 0,
+ }
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+ }
+
+ if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // receive all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
+ var mflags int
+ if n < int64(r.MsgSize) {
+ mflags |= linux.MSG_TRUNC
+ }
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+ return int(n), mflags, from, fromLen, socket.ControlMessages{}, syserr.FromError(err)
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &unix.EndpointReader{
+ Endpoint: s.ep,
+ })
+}
+
+// kernelSCM implements control.SCMCredentials with credentials that represent
+// the kernel itself rather than a Task.
+//
+// +stateify savable
+type kernelSCM struct{}
+
+// Equals implements transport.CredentialsControlMessage.Equals.
+func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool {
+ _, ok := oc.(kernelSCM)
+ return ok
+}
+
+// Credentials implements control.SCMCredentials.Credentials.
+func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ return 0, auth.RootUID, auth.RootGID
+}
+
+// kernelCreds is the concrete version of kernelSCM used in all creds.
+var kernelCreds = &kernelSCM{}
+
+// sendResponse sends the response messages in ms back to userspace.
+func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
+ // Linux combines multiple netlink messages into a single datagram.
+ bufs := make([][]byte, 0, len(ms.Messages))
+ for _, m := range ms.Messages {
+ bufs = append(bufs, m.Finalize())
+ }
+
+ // All messages are from the kernel.
+ cms := transport.ControlMessages{
+ Credentials: kernelCreds,
+ }
+
+ if len(bufs) > 0 {
+ // RecvMsg never receives the address, so we don't need to send
+ // one.
+ _, notify, err := s.connection.Send(bufs, cms, tcpip.FullAddress{})
+ // If the buffer is full, we simply drop messages, just like
+ // Linux.
+ if err != nil && err != syserr.ErrWouldBlock {
+ return err
+ }
+ if notify {
+ s.connection.SendNotify()
+ }
+ }
+
+ // N.B. multi-part messages should still send NLMSG_DONE even if
+ // MessageSet contains no messages.
+ //
+ // N.B. NLMSG_DONE is always sent in a different datagram. See
+ // net/netlink/af_netlink.c:netlink_dump.
+ if ms.Multi {
+ m := NewMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_DONE,
+ Flags: linux.NLM_F_MULTI,
+ Seq: ms.Seq,
+ PortID: uint32(ms.PortID),
+ })
+
+ // Add the dump_done_errno payload.
+ m.Put(int64(0))
+
+ _, notify, err := s.connection.Send([][]byte{m.Finalize()}, cms, tcpip.FullAddress{})
+ if err != nil && err != syserr.ErrWouldBlock {
+ return err
+ }
+ if notify {
+ s.connection.SendNotify()
+ }
+ }
+
+ return nil
+}
+
+func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_ERROR,
+ })
+ m.Put(linux.NetlinkErrorMessage{
+ Error: int32(-err.ToLinux().Number()),
+ Header: hdr,
+ })
+}
+
+func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) {
+ m := ms.AddMessage(linux.NetlinkMessageHeader{
+ Type: linux.NLMSG_ERROR,
+ })
+ m.Put(linux.NetlinkErrorMessage{
+ Error: 0,
+ Header: hdr,
+ })
+}
+
+// processMessages handles each message in buf, passing it to the protocol
+// handler for final handling.
+func (s *socketOpsCommon) processMessages(ctx context.Context, buf []byte) *syserr.Error {
+ for len(buf) > 0 {
+ msg, rest, ok := ParseMessage(buf)
+ if !ok {
+ // Linux ignores messages that are too short. See
+ // net/netlink/af_netlink.c:netlink_rcv_skb.
+ break
+ }
+ buf = rest
+ hdr := msg.Header()
+
+ // Ignore control messages.
+ if hdr.Type < linux.NLMSG_MIN_TYPE {
+ continue
+ }
+
+ ms := NewMessageSet(s.portID, hdr.Seq)
+ if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil {
+ dumpErrorMesage(hdr, ms, err)
+ } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
+ dumpAckMesage(hdr, ms)
+ }
+
+ if err := s.sendResponse(ctx, ms); err != nil {
+ return err
+ }
+ }
+
+ return nil
+}
+
+// sendMsg is the core of message send, used for SendMsg and Write.
+func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ dstPort := int32(0)
+
+ if len(to) != 0 {
+ a, err := ExtractSockAddr(to)
+ if err != nil {
+ return 0, err
+ }
+
+ // No support for multicast groups yet.
+ if a.Groups != 0 {
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ dstPort = int32(a.PortID)
+ }
+
+ if dstPort != 0 {
+ // Non-kernel destinations not supported yet. Treat as if
+ // NL_CFG_F_NONROOT_SEND is not set.
+ return 0, syserr.ErrPermissionDenied
+ }
+
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ // For simplicity, and consistency with Linux, we copy in the entire
+ // message up front.
+ if src.NumBytes() > int64(s.sendBufferSize) {
+ return 0, syserr.ErrMessageTooLong
+ }
+
+ buf := make([]byte, src.NumBytes())
+ n, err := src.CopyIn(ctx, buf)
+ if err != nil {
+ // Don't partially consume messages.
+ return 0, syserr.FromError(err)
+ }
+
+ if err := s.processMessages(ctx, buf); err != nil {
+ return 0, err
+ }
+
+ return n, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ return s.sendMsg(t, src, to, flags, controlMessages)
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
+ return int64(n), err.ToError()
+}
+
+// State implements socket.Socket.State.
+func (s *socketOpsCommon) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ return linux.AF_NETLINK, s.skType, s.protocol.Protocol()
+}
+
+// LINT.ThenChange(./socket_vfs2.go)
diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go
new file mode 100644
index 000000000..dbcd8b49a
--- /dev/null
+++ b/pkg/sentry/socket/netlink/socket_vfs2.go
@@ -0,0 +1,152 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 is the base VFS2 socket type for netlink sockets.
+//
+// This implementation only supports userspace sending and receiving messages
+// to/from the kernel.
+//
+// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ socket.SocketVFS2 = (*SocketVFS2)(nil)
+var _ transport.Credentialer = (*SocketVFS2)(nil)
+
+// NewVFS2 creates a new SocketVFS2.
+func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketVFS2, *syserr.Error) {
+ // Datagram endpoint used to buffer kernel -> user messages.
+ ep := transport.NewConnectionless(t)
+
+ // Bind the endpoint for good measure so we can connect to it. The
+ // bound address will never be exposed.
+ if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ // Create a connection from which the kernel can write messages.
+ connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+
+ fd := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ports: t.Kernel().NetlinkPorts(),
+ protocol: protocol,
+ skType: skType,
+ ep: ep,
+ connection: connection,
+ sendBufferSize: defaultSendBufferSize,
+ },
+ }
+ fd.LockFD.Init(&vfs.FileLocks{})
+ return fd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (*SocketVFS2) Ioctl(context.Context, usermem.IO, arch.SyscallArguments) (uintptr, error) {
+ // TODO(b/68878065): no ioctls supported.
+ return 0, syserror.ENOTTY
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &unix.EndpointReader{
+ Endpoint: s.ep,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{})
+ return int64(n), err.ToError()
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD
new file mode 100644
index 000000000..b6434923c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "uevent",
+ srcs = ["protocol.go"],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go
new file mode 100644
index 000000000..029ba21b5
--- /dev/null
+++ b/pkg/sentry/socket/netlink/uevent/protocol.go
@@ -0,0 +1,60 @@
+// Copyright 2019 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol.
+//
+// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does
+// not support any device events, so these sockets never send any messages.
+package uevent
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.dev/gvisor/pkg/syserr"
+)
+
+// Protocol implements netlink.Protocol.
+//
+// +stateify savable
+type Protocol struct{}
+
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+ return &Protocol{}, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_KOBJECT_UEVENT
+}
+
+// CanSend implements netlink.Protocol.CanSend.
+func (p *Protocol) CanSend() bool {
+ return false
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ // Silently ignore all messages.
+ return nil
+}
+
+// init registers the NETLINK_KOBJECT_UEVENT provider.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol)
+}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
new file mode 100644
index 000000000..ea6ebd0e2
--- /dev/null
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -0,0 +1,56 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "netstack",
+ srcs = [
+ "device.go",
+ "netstack.go",
+ "netstack_vfs2.go",
+ "provider.go",
+ "provider_vfs2.go",
+ "save_restore.go",
+ "stack.go",
+ ],
+ visibility = [
+ "//pkg/sentry:internal",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/amutex",
+ "//pkg/binary",
+ "//pkg/context",
+ "//pkg/log",
+ "//pkg/metric",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/netfilter",
+ "//pkg/sentry/unimpl",
+ "//pkg/sentry/vfs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/header",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ "@org_golang_x_sys//unix:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/socket/netstack/device.go b/pkg/sentry/socket/netstack/device.go
new file mode 100644
index 000000000..fbeb89fb8
--- /dev/null
+++ b/pkg/sentry/socket/netstack/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// netstackDevice is the endpoint socket virtual device.
+var netstackDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
new file mode 100644
index 000000000..3b248a953
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -0,0 +1,3143 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netstack provides an implementation of the socket.Socket interface
+// that is backed by a tcpip.Endpoint.
+//
+// It does not depend on any particular endpoint implementation, and thus can
+// be used to expose certain endpoints to the sentry while leaving others out,
+// for example, TCP endpoints and Unix-domain endpoints.
+//
+// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
+// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
+// this operation.
+package netstack
+
+import (
+ "bytes"
+ "io"
+ "math"
+ "reflect"
+ "sync/atomic"
+ "syscall"
+ "time"
+
+ "golang.org/x/sys/unix"
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/metric"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
+ "gvisor.dev/gvisor/pkg/sentry/unimpl"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+func mustCreateMetric(name, description string) *tcpip.StatCounter {
+ var cm tcpip.StatCounter
+ metric.MustRegisterCustomUint64Metric(name, true /* cumulative */, false /* sync */, description, cm.Value)
+ return &cm
+}
+
+func mustCreateGauge(name, description string) *tcpip.StatCounter {
+ var cm tcpip.StatCounter
+ metric.MustRegisterCustomUint64Metric(name, false /* cumulative */, false /* sync */, description, cm.Value)
+ return &cm
+}
+
+// Metrics contains metrics exported by netstack.
+var Metrics = tcpip.Stats{
+ UnknownProtocolRcvdPackets: mustCreateMetric("/netstack/unknown_protocol_received_packets", "Number of packets received by netstack that were for an unknown or unsupported protocol."),
+ MalformedRcvdPackets: mustCreateMetric("/netstack/malformed_received_packets", "Number of packets received by netstack that were deemed malformed."),
+ DroppedPackets: mustCreateMetric("/netstack/dropped_packets", "Number of packets dropped by netstack due to full queues."),
+ ICMP: tcpip.ICMPStats{
+ V4PacketsSent: tcpip.ICMPv4SentPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo", "Total number of ICMPv4 echo packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/echo_reply", "Total number of ICMPv4 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_sent/dst_unreachable", "Total number of ICMPv4 destination unreachable packets sent by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_sent/src_quench", "Total number of ICMPv4 source quench packets sent by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_sent/redirect", "Total number of ICMPv4 redirect packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_sent/time_exceeded", "Total number of ICMPv4 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_sent/param_problem", "Total number of ICMPv4 parameter problem packets sent by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp", "Total number of ICMPv4 timestamp packets sent by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/timestamp_reply", "Total number of ICMPv4 timestamp reply packets sent by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_request", "Total number of ICMPv4 information request packets sent by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_sent/info_reply", "Total number of ICMPv4 information reply packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v4/packets_sent/dropped", "Total number of ICMPv4 packets dropped by netstack due to link layer errors."),
+ },
+ V4PacketsReceived: tcpip.ICMPv4ReceivedPacketStats{
+ ICMPv4PacketStats: tcpip.ICMPv4PacketStats{
+ Echo: mustCreateMetric("/netstack/icmp/v4/packets_received/echo", "Total number of ICMPv4 echo packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/echo_reply", "Total number of ICMPv4 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v4/packets_received/dst_unreachable", "Total number of ICMPv4 destination unreachable packets received by netstack."),
+ SrcQuench: mustCreateMetric("/netstack/icmp/v4/packets_received/src_quench", "Total number of ICMPv4 source quench packets received by netstack."),
+ Redirect: mustCreateMetric("/netstack/icmp/v4/packets_received/redirect", "Total number of ICMPv4 redirect packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v4/packets_received/time_exceeded", "Total number of ICMPv4 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v4/packets_received/param_problem", "Total number of ICMPv4 parameter problem packets received by netstack."),
+ Timestamp: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp", "Total number of ICMPv4 timestamp packets received by netstack."),
+ TimestampReply: mustCreateMetric("/netstack/icmp/v4/packets_received/timestamp_reply", "Total number of ICMPv4 timestamp reply packets received by netstack."),
+ InfoRequest: mustCreateMetric("/netstack/icmp/v4/packets_received/info_request", "Total number of ICMPv4 information request packets received by netstack."),
+ InfoReply: mustCreateMetric("/netstack/icmp/v4/packets_received/info_reply", "Total number of ICMPv4 information reply packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v4/packets_received/invalid", "Total number of ICMPv4 packets received that the transport layer could not parse."),
+ },
+ V6PacketsSent: tcpip.ICMPv6SentPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_request", "Total number of ICMPv6 echo request packets sent by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_sent/echo_reply", "Total number of ICMPv6 echo reply packets sent by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_sent/dst_unreachable", "Total number of ICMPv6 destination unreachable packets sent by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_sent/packet_too_big", "Total number of ICMPv6 packet too big packets sent by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_sent/time_exceeded", "Total number of ICMPv6 time exceeded packets sent by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_sent/param_problem", "Total number of ICMPv6 parameter problem packets sent by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_solicit", "Total number of ICMPv6 router solicit packets sent by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/router_advert", "Total number of ICMPv6 router advert packets sent by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets sent by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_sent/neighbor_advert", "Total number of ICMPv6 neighbor advert packets sent by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_sent/redirect_msg", "Total number of ICMPv6 redirect message packets sent by netstack."),
+ },
+ Dropped: mustCreateMetric("/netstack/icmp/v6/packets_sent/dropped", "Total number of ICMPv6 packets dropped by netstack due to link layer errors."),
+ },
+ V6PacketsReceived: tcpip.ICMPv6ReceivedPacketStats{
+ ICMPv6PacketStats: tcpip.ICMPv6PacketStats{
+ EchoRequest: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_request", "Total number of ICMPv6 echo request packets received by netstack."),
+ EchoReply: mustCreateMetric("/netstack/icmp/v6/packets_received/echo_reply", "Total number of ICMPv6 echo reply packets received by netstack."),
+ DstUnreachable: mustCreateMetric("/netstack/icmp/v6/packets_received/dst_unreachable", "Total number of ICMPv6 destination unreachable packets received by netstack."),
+ PacketTooBig: mustCreateMetric("/netstack/icmp/v6/packets_received/packet_too_big", "Total number of ICMPv6 packet too big packets received by netstack."),
+ TimeExceeded: mustCreateMetric("/netstack/icmp/v6/packets_received/time_exceeded", "Total number of ICMPv6 time exceeded packets received by netstack."),
+ ParamProblem: mustCreateMetric("/netstack/icmp/v6/packets_received/param_problem", "Total number of ICMPv6 parameter problem packets received by netstack."),
+ RouterSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/router_solicit", "Total number of ICMPv6 router solicit packets received by netstack."),
+ RouterAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/router_advert", "Total number of ICMPv6 router advert packets received by netstack."),
+ NeighborSolicit: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_solicit", "Total number of ICMPv6 neighbor solicit packets received by netstack."),
+ NeighborAdvert: mustCreateMetric("/netstack/icmp/v6/packets_received/neighbor_advert", "Total number of ICMPv6 neighbor advert packets received by netstack."),
+ RedirectMsg: mustCreateMetric("/netstack/icmp/v6/packets_received/redirect_msg", "Total number of ICMPv6 redirect message packets received by netstack."),
+ },
+ Invalid: mustCreateMetric("/netstack/icmp/v6/packets_received/invalid", "Total number of ICMPv6 packets received that the transport layer could not parse."),
+ },
+ },
+ IP: tcpip.IPStats{
+ PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."),
+ InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."),
+ InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."),
+ PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."),
+ PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."),
+ OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
+ MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
+ },
+ TCP: tcpip.TCPStats{
+ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
+ PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
+ CurrentEstablished: mustCreateGauge("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."),
+ CurrentConnected: mustCreateGauge("/netstack/tcp/current_open", "Number of connections that are in connected state."),
+ EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to CLOSED state."),
+ EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."),
+ ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
+ ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
+ ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."),
+ ListenOverflowSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_rcvd", "Number of times a SYN cookie was received."),
+ ListenOverflowInvalidSynCookieRcvd: mustCreateMetric("/netstack/tcp/listen_overflow_invalid_syn_cookie_rcvd", "Number of times an invalid SYN cookie was received."),
+ FailedConnectionAttempts: mustCreateMetric("/netstack/tcp/failed_connection_attempts", "Number of calls to Connect or Listen (active and passive openings, respectively) that end in an error."),
+ ValidSegmentsReceived: mustCreateMetric("/netstack/tcp/valid_segments_received", "Number of TCP segments received that the transport layer successfully parsed."),
+ InvalidSegmentsReceived: mustCreateMetric("/netstack/tcp/invalid_segments_received", "Number of TCP segments received that the transport layer could not parse."),
+ SegmentsSent: mustCreateMetric("/netstack/tcp/segments_sent", "Number of TCP segments sent."),
+ SegmentSendErrors: mustCreateMetric("/netstack/tcp/segment_send_errors", "Number of TCP segments failed to be sent."),
+ ResetsSent: mustCreateMetric("/netstack/tcp/resets_sent", "Number of TCP resets sent."),
+ ResetsReceived: mustCreateMetric("/netstack/tcp/resets_received", "Number of TCP resets received."),
+ Retransmits: mustCreateMetric("/netstack/tcp/retransmits", "Number of TCP segments retransmitted."),
+ FastRecovery: mustCreateMetric("/netstack/tcp/fast_recovery", "Number of times fast recovery was used to recover from packet loss."),
+ SACKRecovery: mustCreateMetric("/netstack/tcp/sack_recovery", "Number of times SACK recovery was used to recover from packet loss."),
+ SlowStartRetransmits: mustCreateMetric("/netstack/tcp/slow_start_retransmits", "Number of segments retransmitted in slow start mode."),
+ FastRetransmit: mustCreateMetric("/netstack/tcp/fast_retransmit", "Number of TCP segments which were fast retransmitted."),
+ Timeouts: mustCreateMetric("/netstack/tcp/timeouts", "Number of times RTO expired."),
+ ChecksumErrors: mustCreateMetric("/netstack/tcp/checksum_errors", "Number of segments dropped due to bad checksums."),
+ },
+ UDP: tcpip.UDPStats{
+ PacketsReceived: mustCreateMetric("/netstack/udp/packets_received", "Number of UDP datagrams received via HandlePacket."),
+ UnknownPortErrors: mustCreateMetric("/netstack/udp/unknown_port_errors", "Number of incoming UDP datagrams dropped because they did not have a known destination port."),
+ ReceiveBufferErrors: mustCreateMetric("/netstack/udp/receive_buffer_errors", "Number of incoming UDP datagrams dropped due to the receiving buffer being in an invalid state."),
+ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."),
+ PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."),
+ PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."),
+ ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."),
+ },
+}
+
+// DefaultTTL is linux's default TTL. All network protocols in all stacks used
+// with this package must have this value set as their default TTL.
+const DefaultTTL = 64
+
+const sizeOfInt32 int = 4
+
+var errStackType = syserr.New("expected but did not receive a netstack.Stack", linux.EINVAL)
+
+// ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian.
+func ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian.
+func htons(v uint16) uint16 {
+ return ntohs(v)
+}
+
+// commonEndpoint represents the intersection of a tcpip.Endpoint and a
+// transport.Endpoint.
+type commonEndpoint interface {
+ // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and
+ // transport.Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and
+ // transport.Endpoint.GetRemoteAddress.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Readiness implements tcpip.Endpoint.Readiness and
+ // transport.Endpoint.Readiness.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt implements tcpip.Endpoint.SetSockOpt and
+ // transport.Endpoint.SetSockOpt.
+ SetSockOpt(interface{}) *tcpip.Error
+
+ // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
+ // transport.Endpoint.SetSockOptBool.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
+ // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and
+ // transport.Endpoint.SetSockOptInt.
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
+
+ // GetSockOpt implements tcpip.Endpoint.GetSockOpt and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOpt(interface{}) *tcpip.Error
+
+ // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
+ // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
+ // transport.Endpoint.GetSockOpt.
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+}
+
+// LINT.IfChange
+
+// SocketOperations encapsulates all the state needed to represent a network stack
+// endpoint in the kernel context.
+//
+// +stateify savable
+type SocketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ socket.SendReceiveTimeout
+ *waiter.Queue
+
+ family int
+ Endpoint tcpip.Endpoint
+ skType linux.SockType
+ protocol int
+
+ // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
+ // Must be accessed using atomic operations. It must only be written
+ // with readMu held but can be read without holding readMu. The latter
+ // is required to avoid deadlocks in epoll Readiness checks.
+ readViewHasData uint32
+
+ // readMu protects access to the below fields.
+ readMu sync.Mutex `state:"nosave"`
+ // readView contains the remaining payload from the last packet.
+ readView buffer.View
+ // readCM holds control message information for the last packet read
+ // from Endpoint.
+ readCM tcpip.ControlMessages
+ sender tcpip.FullAddress
+
+ // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps
+ // of returned messages can be returned via control messages. When
+ // false, the same timestamp is instead stored and can be read via the
+ // SIOCGSTAMP ioctl. It is protected by readMu. See socket(7).
+ sockOptTimestamp bool
+ // timestampValid indicates whether timestamp for SIOCGSTAMP has been
+ // set. It is protected by readMu.
+ timestampValid bool
+ // timestampNS holds the timestamp to use with SIOCTSTAMP. It is only
+ // valid when timestampValid is true. It is protected by readMu.
+ timestampNS int64
+
+ // sockOptInq corresponds to TCP_INQ. It is implemented at this level
+ // because it takes into account data from readView.
+ sockOptInq bool
+}
+
+// New creates a new endpoint socket.
+func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
+ if skType == linux.SOCK_STREAM {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ }
+
+ dirent := socket.NewDirent(t, netstackDevice)
+ defer dirent.DecRef()
+ return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{
+ socketOpsCommon: socketOpsCommon{
+ Queue: queue,
+ family: family,
+ Endpoint: endpoint,
+ skType: skType,
+ protocol: protocol,
+ },
+ }), nil
+}
+
+var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
+var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+var sockAddrLinkSize = int(binary.Size(linux.SockAddrLink{}))
+
+// bytesToIPAddress converts an IPv4 or IPv6 address from the user to the
+// netstack representation taking any addresses into account.
+func bytesToIPAddress(addr []byte) tcpip.Address {
+ if bytes.Equal(addr, make([]byte, 4)) || bytes.Equal(addr, make([]byte, 16)) {
+ return ""
+ }
+ return tcpip.Address(addr)
+}
+
+// AddressAndFamily reads an sockaddr struct from the given address and
+// converts it to the FullAddress format. It supports AF_UNIX, AF_INET,
+// AF_INET6, and AF_PACKET addresses.
+//
+// AddressAndFamily returns an address and its family.
+func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family := usermem.ByteOrder.Uint16(addr); family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ if len(path) > linux.UnixPathMax {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ // Drop the terminating NUL (if one exists) and everything after
+ // it for filesystem (non-abstract) addresses.
+ if len(path) > 0 && path[0] != 0 {
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, family, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: bytesToIPAddress(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ return out, family, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: bytesToIPAddress(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ return out, family, nil
+
+ case linux.AF_PACKET:
+ var a linux.SockAddrLink
+ if len(addr) < sockAddrLinkSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(addr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+ if a.Family != linux.AF_PACKET || a.HardwareAddrLen != header.EthernetAddressSize {
+ return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument
+ }
+
+ // TODO(b/129292371): Return protocol too.
+ return tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }, family, nil
+
+ case linux.AF_UNSPEC:
+ return tcpip.FullAddress{}, family, nil
+
+ default:
+ return tcpip.FullAddress{}, 0, syserr.ErrAddressFamilyNotSupported
+ }
+}
+
+func (s *socketOpsCommon) isPacketBased() bool {
+ return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW
+}
+
+// fetchReadView updates the readView field of the socket if it's currently
+// empty. It assumes that the socket is locked.
+//
+// Precondition: s.readMu must be held.
+func (s *socketOpsCommon) fetchReadView() *syserr.Error {
+ if len(s.readView) > 0 {
+ return nil
+ }
+ s.readView = nil
+ s.sender = tcpip.FullAddress{}
+
+ v, cms, err := s.Endpoint.Read(&s.sender)
+ if err != nil {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ return syserr.TranslateNetstackError(err)
+ }
+
+ s.readView = v
+ s.readCM = cms
+ atomic.StoreUint32(&s.readViewHasData, 1)
+
+ return nil
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ s.Endpoint.Close()
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ if err == syserr.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, err.ToError()
+ }
+ return int64(n), nil
+}
+
+// WriteTo implements fs.FileOperations.WriteTo.
+func (s *SocketOperations) WriteTo(ctx context.Context, _ *fs.File, dst io.Writer, count int64, dup bool) (int64, error) {
+ s.readMu.Lock()
+
+ // Copy as much data as possible.
+ done := int64(0)
+ for count > 0 {
+ // This may return a blocking error.
+ if err := s.fetchReadView(); err != nil {
+ s.readMu.Unlock()
+ return done, err.ToError()
+ }
+
+ // Write to the underlying file.
+ n, err := dst.Write(s.readView)
+ done += int64(n)
+ count -= int64(n)
+ if dup {
+ // That's all we support for dup. This is generally
+ // supported by any Linux system calls, but the
+ // expectation is that now a caller will call read to
+ // actually remove these bytes from the socket.
+ break
+ }
+
+ // Drop that part of the view.
+ s.readView.TrimFront(n)
+ if err != nil {
+ s.readMu.Unlock()
+ return done, err
+ }
+ }
+
+ s.readMu.Unlock()
+ return done, nil
+}
+
+// ioSequencePayload implements tcpip.Payload.
+//
+// t copies user memory bytes on demand based on the requested size.
+type ioSequencePayload struct {
+ ctx context.Context
+ src usermem.IOSequence
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload
+func (i *ioSequencePayload) FullPayload() ([]byte, *tcpip.Error) {
+ return i.Payload(int(i.src.NumBytes()))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (i *ioSequencePayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if max := int(i.src.NumBytes()); size > max {
+ size = max
+ }
+ v := buffer.NewView(size)
+ if _, err := i.src.CopyIn(i.ctx, v); err != nil {
+ return nil, tcpip.ErrBadAddress
+ }
+ return v, nil
+}
+
+// DropFirst drops the first n bytes from underlying src.
+func (i *ioSequencePayload) DropFirst(n int) {
+ i.src = i.src.DropFirst(int(n))
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ f := &ioSequencePayload{ctx: ctx, src: src}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
+ }
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ }
+
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
+}
+
+// readerPayload implements tcpip.Payloader.
+//
+// It allocates a view and reads from a reader on-demand, based on available
+// capacity in the endpoint.
+type readerPayload struct {
+ ctx context.Context
+ r io.Reader
+ count int64
+ err error
+}
+
+// FullPayload implements tcpip.Payloader.FullPayload.
+func (r *readerPayload) FullPayload() ([]byte, *tcpip.Error) {
+ return r.Payload(int(r.count))
+}
+
+// Payload implements tcpip.Payloader.Payload.
+func (r *readerPayload) Payload(size int) ([]byte, *tcpip.Error) {
+ if size > int(r.count) {
+ size = int(r.count)
+ }
+ v := buffer.NewView(size)
+ n, err := r.r.Read(v)
+ if n > 0 {
+ // We ignore the error here. It may re-occur on subsequent
+ // reads, but for now we can enqueue some amount of data.
+ r.count -= int64(n)
+ return v[:n], nil
+ }
+ if err == syserror.ErrWouldBlock {
+ return nil, tcpip.ErrWouldBlock
+ } else if err != nil {
+ r.err = err // Save for propation.
+ return nil, tcpip.ErrBadAddress
+ }
+
+ // There is no data and no error. Return an error, which will propagate
+ // r.err, which will be nil. This is the desired result: (0, nil).
+ return nil, tcpip.ErrBadAddress
+}
+
+// ReadFrom implements fs.FileOperations.ReadFrom.
+func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader, count int64) (int64, error) {
+ f := &readerPayload{ctx: ctx, r: r, count: count}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{
+ // Reads may be destructive but should be very fast,
+ // so we can't release the lock while copying data.
+ Atomic: true,
+ })
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
+ }
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{
+ Atomic: true, // See above.
+ })
+ }
+ if err == tcpip.ErrWouldBlock {
+ return n, syserror.ErrWouldBlock
+ } else if err != nil {
+ return int64(n), f.err // Propagate error.
+ }
+
+ return int64(n), nil
+}
+
+// Readiness returns a mask of ready events for socket s.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ r := s.Endpoint.Readiness(mask)
+
+ // Check our cached value iff the caller asked for readability and the
+ // endpoint itself is currently not readable.
+ if (mask & ^r & waiter.EventIn) != 0 {
+ if atomic.LoadUint32(&s.readViewHasData) == 1 {
+ r |= waiter.EventIn
+ }
+ }
+
+ return r
+}
+
+func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error {
+ if family == uint16(s.family) {
+ return nil
+ }
+ if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 {
+ v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ if !v {
+ return nil
+ }
+ }
+ return syserr.ErrInvalidArgument
+}
+
+// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the
+// receiver's family is AF_INET6.
+//
+// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
+// represented by the empty string.
+//
+// TODO(gvisor.dev/issue/1556): remove this function.
+func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
+ if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
+ addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
+ }
+ return addr
+}
+
+// Connect implements the linux syscall connect(2) for sockets backed by
+// tpcip.Endpoint.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ addr, family, err := AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if family == linux.AF_UNSPEC {
+ err := s.Endpoint.Disconnect()
+ if err == tcpip.ErrNotSupported {
+ return syserr.ErrAddressFamilyNotSupported
+ }
+ return syserr.TranslateNetstackError(err)
+ }
+
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return err
+ }
+ addr = s.mapFamily(addr, family)
+
+ // Always return right away in the non-blocking case.
+ if !blocking {
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+ }
+
+ // Register for notification when the endpoint becomes writable, then
+ // initiate the connection.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM {
+ // TCP unlike UDP returns EADDRNOTAVAIL when it can't
+ // find an available local ephemeral port.
+ if err == tcpip.ErrNoPortAvailable {
+ return syserr.ErrAddressNotAvailable
+ }
+ }
+
+ return syserr.TranslateNetstackError(err)
+ }
+
+ // It's pending, so we have to wait for a notification, and fetch the
+ // result once the wait completes.
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+
+ // Call Connect() again after blocking to find connect's result.
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+}
+
+// Bind implements the linux syscall bind(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
+ }
+
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only family, protocol and ifindex.
+ // In function AddressAndFamily, we check the address length which is
+ // not needed for AF_PACKET bind.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
+ }
+
+ // Issue the bind request to the endpoint.
+ return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection again; if it fails, then wait until we
+ // get a notification.
+ for {
+ if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock {
+ return ep, wq, syserr.TranslateNetstackError(err)
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, wq, terr := s.Endpoint.Accept()
+ if terr != nil {
+ if terr != tcpip.ErrWouldBlock || !blocking {
+ return 0, nil, 0, syserr.TranslateNetstackError(terr)
+ }
+
+ var err *syserr.Error
+ ep, wq, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := New(t, s.family, s.skType, s.protocol, wq, ep)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer and write it to peer slice.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+
+ t.Kernel().RecordSocket(ns)
+
+ return fd, addr, addrLen, syserr.FromError(e)
+}
+
+// ConvertShutdown converts Linux shutdown flags into tcpip shutdown flags.
+func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
+ var f tcpip.ShutdownFlags
+ switch how {
+ case linux.SHUT_RD:
+ f = tcpip.ShutdownRead
+ case linux.SHUT_WR:
+ f = tcpip.ShutdownWrite
+ case linux.SHUT_RDWR:
+ f = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ default:
+ return 0, syserr.ErrInvalidArgument
+ }
+ return f, nil
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ f, err := ConvertShutdown(how)
+ if err != nil {
+ return err
+ }
+
+ // Issue shutdown request.
+ return syserr.TranslateNetstackError(s.Endpoint.Shutdown(f))
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketOperations rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptTimestamp {
+ val = 1
+ }
+ return val, nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return val, nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
+ if err != nil {
+ return nil, err
+ }
+ return info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return entries, nil
+
+ }
+ }
+
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+}
+
+// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
+// sockets backed by a commonEndpoint.
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+ switch level {
+ case linux.SOL_SOCKET:
+ return getSockOptSocket(t, s, ep, family, skType, name, outLen)
+
+ case linux.SOL_TCP:
+ return getSockOptTCP(t, ep, name, outLen)
+
+ case linux.SOL_IPV6:
+ return getSockOptIPv6(t, ep, name, outLen)
+
+ case linux.SOL_IP:
+ return getSockOptIP(t, ep, name, outLen, family)
+
+ case linux.SOL_UDP,
+ linux.SOL_ICMPV6,
+ linux.SOL_RAW,
+ linux.SOL_PACKET:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+func boolToInt32(v bool) int32 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
+// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
+ switch name {
+ case linux.SO_ERROR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Get the last error and convert it.
+ err := ep.GetSockOpt(tcpip.ErrorOption{})
+ if err == nil {
+ return int32(0), nil
+ }
+ return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil
+
+ case linux.SO_PEERCRED:
+ if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ tcred := t.Credentials()
+ return syscall.Ucred{
+ Pid: int32(t.ThreadGroup().ID()),
+ Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }, nil
+
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.PasscredOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_SNDBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ size, err := ep.GetSockOptInt(tcpip.SendBufferSizeOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_RCVBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ size, err := ep.GetSockOptInt(tcpip.ReceiveBufferSizeOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_REUSEADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_REUSEPORT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_BINDTODEVICE:
+ var v tcpip.BindToDeviceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if v == 0 {
+ return []byte{}, nil
+ }
+ if outLen < linux.IFNAMSIZ {
+ return nil, syserr.ErrInvalidArgument
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ nic, ok := s.Interfaces()[int32(v)]
+ if !ok {
+ // The NICID no longer indicates a valid interface, probably because that
+ // interface was removed.
+ return nil, syserr.ErrUnknownDevice
+ }
+ return append([]byte(nic.Name), 0), nil
+
+ case linux.SO_BROADCAST:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_KEEPALIVE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.SO_LINGER:
+ if outLen < linux.SizeOfLinger {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return linux.Linger{}, nil
+
+ case linux.SO_SNDTIMEO:
+ // TODO(igudger): Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.SendTimeout()), nil
+
+ case linux.SO_RCVTIMEO:
+ // TODO(igudger): Linux allows shorter lengths for partial results.
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return linux.NsecToTimeval(s.RecvTimeout()), nil
+
+ case linux.SO_OOBINLINE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.OutOfBandInlineOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_NO_CHECK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.NoChecksumOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ default:
+ socket.GetSockOptEmitUnimplementedEvent(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptTCP implements GetSockOpt when level is SOL_TCP.
+func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.TCP_NODELAY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.DelayOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(!v), nil
+
+ case linux.TCP_CORK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.CorkOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.TCP_QUICKACK:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.TCP_MAXSEG:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.MaxSegOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_KEEPIDLE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.KeepaliveIdleOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_KEEPINTVL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.KeepaliveIntervalOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_KEEPCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_USER_TIMEOUT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPUserTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Millisecond), nil
+
+ case linux.TCP_INFO:
+ var v tcpip.TCPInfoOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // TODO(b/64800844): Translate fields once they are added to
+ // tcpip.TCPInfoOption.
+ info := linux.TCPInfo{}
+
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &info)
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+
+ return ib, nil
+
+ case linux.TCP_CC_INFO,
+ linux.TCP_NOTSENT_LOWAT,
+ linux.TCP_ZEROCOPY_RECEIVE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ case linux.TCP_CONGESTION:
+ if outLen <= 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.CongestionControlOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // We match linux behaviour here where it returns the lower of
+ // TCP_CA_NAME_MAX bytes or the value of the option length.
+ //
+ // This is Linux's net/tcp.h TCP_CA_NAME_MAX.
+ const tcpCANameMax = 16
+
+ toCopy := tcpCANameMax
+ if outLen < tcpCANameMax {
+ toCopy = outLen
+ }
+ b := make([]byte, toCopy)
+ copy(b, v)
+ return b, nil
+
+ case linux.TCP_LINGER2:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPLingerTimeoutOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_DEFER_ACCEPT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.TCPDeferAcceptOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(time.Duration(v) / time.Second), nil
+
+ case linux.TCP_SYNCNT:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPSynCountOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.TCP_WINDOW_CLAMP:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TCPWindowClampOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+ default:
+ emitUnimplementedEventTCP(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
+func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.IPV6_V6ONLY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.V6OnlyOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.IPV6_PATHMTU:
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ case linux.IPV6_TCLASS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return make([]byte, 0), nil
+ }
+ v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ uintv := uint32(v)
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &uintv)
+ // Handle cases where outLen is lesser than sizeOfInt32.
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+ return ib, nil
+
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ default:
+ emitUnimplementedEventIPv6(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// getSockOptIP implements GetSockOpt when level is SOL_IP.
+func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) {
+ switch name {
+ case linux.IP_TTL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.TTLOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // Fill in the default value, if needed.
+ if v == 0 {
+ v = DefaultTTL
+ }
+
+ return int32(v), nil
+
+ case linux.IP_MULTICAST_TTL:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.IP_MULTICAST_IF:
+ if outLen < len(linux.InetAddr{}) {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.MulticastInterfaceOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr})
+
+ return a.(*linux.SockAddrInet).Addr, nil
+
+ case linux.IP_MULTICAST_LOOP:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.IP_TOS:
+ // Length handling for parity with Linux.
+ if outLen == 0 {
+ return []byte(nil), nil
+ }
+ v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ if outLen < sizeOfInt32 {
+ return uint8(v), nil
+ }
+ return int32(v), nil
+
+ case linux.IP_RECVTOS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
+ default:
+ emitUnimplementedEventIP(t, name)
+ }
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketOperations rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
+ return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
+}
+
+// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
+// sockets backed by a commonEndpoint.
+func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+ switch level {
+ case linux.SOL_SOCKET:
+ return setSockOptSocket(t, s, ep, name, optVal)
+
+ case linux.SOL_TCP:
+ return setSockOptTCP(t, ep, name, optVal)
+
+ case linux.SOL_IPV6:
+ return setSockOptIPv6(t, ep, name, optVal)
+
+ case linux.SOL_IP:
+ return setSockOptIP(t, ep, name, optVal)
+
+ case linux.SOL_UDP,
+ linux.SOL_ICMPV6,
+ linux.SOL_RAW,
+ linux.SOL_PACKET:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
+func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.SO_SNDBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.SendBufferSizeOption, int(v)))
+
+ case linux.SO_RCVBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.ReceiveBufferSizeOption, int(v)))
+
+ case linux.SO_REUSEADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
+
+ case linux.SO_REUSEPORT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
+
+ case linux.SO_BINDTODEVICE:
+ n := bytes.IndexByte(optVal, 0)
+ if n == -1 {
+ n = len(optVal)
+ }
+ name := string(optVal[:n])
+ if name == "" {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ }
+ s := t.NetworkContext()
+ if s == nil {
+ return syserr.ErrNoDevice
+ }
+ for nicID, nic := range s.Interfaces() {
+ if nic.Name == name {
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ }
+ }
+ return syserr.ErrUnknownDevice
+
+ case linux.SO_BROADCAST:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
+
+ case linux.SO_PASSCRED:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
+
+ case linux.SO_KEEPALIVE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
+
+ case linux.SO_SNDTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetSendTimeout(v.ToNsecCapped())
+ return nil
+
+ case linux.SO_RCVTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) {
+ return syserr.ErrDomain
+ }
+ s.SetRecvTimeout(v.ToNsecCapped())
+ return nil
+
+ case linux.SO_OOBINLINE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ if v == 0 {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v)))
+
+ case linux.SO_NO_CHECK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0))
+
+ case linux.SO_LINGER:
+ if len(optVal) < linux.SizeOfLinger {
+ return syserr.ErrInvalidArgument
+ }
+
+ var v linux.Linger
+ binary.Unmarshal(optVal[:linux.SizeOfLinger], usermem.ByteOrder, &v)
+
+ if v != (linux.Linger{}) {
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ return nil
+
+ default:
+ socket.SetSockOptEmitUnimplementedEvent(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
+func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.TCP_NODELAY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
+
+ case linux.TCP_CORK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
+
+ case linux.TCP_QUICKACK:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
+
+ case linux.TCP_MAXSEG:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
+
+ case linux.TCP_KEEPIDLE:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPIDLE {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_KEEPINTVL:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPINTVL {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_KEEPCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ if v < 1 || v > linux.MAX_TCP_KEEPCNT {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v)))
+
+ case linux.TCP_USER_TIMEOUT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+
+ case linux.TCP_CONGESTION:
+ v := tcpip.CongestionControlOption(optVal)
+ if err := ep.SetSockOpt(v); err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+ return nil
+
+ case linux.TCP_LINGER2:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_DEFER_ACCEPT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < 0 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))))
+
+ case linux.TCP_SYNCNT:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v)))
+
+ case linux.TCP_WINDOW_CLAMP:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := usermem.ByteOrder.Uint32(optVal)
+
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v)))
+
+ case linux.TCP_REPAIR_OPTIONS:
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEventTCP(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
+func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.IPV6_V6ONLY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0))
+
+ case linux.IPV6_ADD_MEMBERSHIP,
+ linux.IPV6_DROP_MEMBERSHIP,
+ linux.IPV6_IPSEC_POLICY,
+ linux.IPV6_JOIN_ANYCAST,
+ linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
+ linux.IPV6_PKTINFO,
+ linux.IPV6_ROUTER_ALERT,
+ linux.IPV6_XFRM_POLICY,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_JOIN_GROUP,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_UNBLOCK_SOURCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ case linux.IPV6_TCLASS:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ if v < -1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ if v == -1 {
+ v = 0
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v)))
+
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+
+ default:
+ emitUnimplementedEventIPv6(t, name)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+var (
+ inetMulticastRequestSize = int(binary.Size(linux.InetMulticastRequest{}))
+ inetMulticastRequestWithNICSize = int(binary.Size(linux.InetMulticastRequestWithNIC{}))
+)
+
+// copyInMulticastRequest copies in a variable-size multicast request. The
+// kernel determines which structure was passed by its length. IP_MULTICAST_IF
+// supports ip_mreqn, ip_mreq and in_addr, while IP_ADD_MEMBERSHIP and
+// IP_DROP_MEMBERSHIP only support ip_mreqn and ip_mreq. To handle this,
+// allowAddr controls whether in_addr is accepted or rejected.
+func copyInMulticastRequest(optVal []byte, allowAddr bool) (linux.InetMulticastRequestWithNIC, *syserr.Error) {
+ if len(optVal) < len(linux.InetAddr{}) {
+ return linux.InetMulticastRequestWithNIC{}, syserr.ErrInvalidArgument
+ }
+
+ if len(optVal) < inetMulticastRequestSize {
+ if !allowAddr {
+ return linux.InetMulticastRequestWithNIC{}, syserr.ErrInvalidArgument
+ }
+
+ var req linux.InetMulticastRequestWithNIC
+ copy(req.InterfaceAddr[:], optVal)
+ return req, nil
+ }
+
+ if len(optVal) >= inetMulticastRequestWithNICSize {
+ var req linux.InetMulticastRequestWithNIC
+ binary.Unmarshal(optVal[:inetMulticastRequestWithNICSize], usermem.ByteOrder, &req)
+ return req, nil
+ }
+
+ var req linux.InetMulticastRequestWithNIC
+ binary.Unmarshal(optVal[:inetMulticastRequestSize], usermem.ByteOrder, &req.InetMulticastRequest)
+ return req, nil
+}
+
+// parseIntOrChar copies either a 32-bit int or an 8-bit uint out of buf.
+//
+// net/ipv4/ip_sockglue.c:do_ip_setsockopt does this for its socket options.
+func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
+ if len(buf) == 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ if len(buf) >= sizeOfInt32 {
+ return int32(usermem.ByteOrder.Uint32(buf)), nil
+ }
+
+ return int32(buf[0]), nil
+}
+
+// setSockOptIP implements SetSockOpt when level is SOL_IP.
+func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+ switch name {
+ case linux.IP_MULTICAST_TTL:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ if v == -1 {
+ // Linux translates -1 to 1.
+ v = 1
+ }
+ if v < 0 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v)))
+
+ case linux.IP_ADD_MEMBERSHIP:
+ req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.AddMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ // TODO(igudger): Change AddMembership to use the standard
+ // any address representation.
+ InterfaceAddr: tcpip.Address(req.InterfaceAddr[:]),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IP_DROP_MEMBERSHIP:
+ req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.RemoveMembershipOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ // TODO(igudger): Change DropMembership to use the standard
+ // any address representation.
+ InterfaceAddr: tcpip.Address(req.InterfaceAddr[:]),
+ MulticastAddr: tcpip.Address(req.MulticastAddr[:]),
+ }))
+
+ case linux.IP_MULTICAST_IF:
+ req, err := copyInMulticastRequest(optVal, true /* allowAddr */)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastInterfaceOption{
+ NIC: tcpip.NICID(req.InterfaceIndex),
+ InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
+ }))
+
+ case linux.IP_MULTICAST_LOOP:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
+
+ case linux.MCAST_JOIN_GROUP:
+ // FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return syserr.ErrInvalidArgument
+
+ case linux.IP_TTL:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ // -1 means default TTL.
+ if v == -1 {
+ v = 0
+ } else if v < 1 || v > 255 {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v)))
+
+ case linux.IP_TOS:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v)))
+
+ case linux.IP_RECVTOS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
+ case linux.IP_HDRINCL:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+
+ case linux.IP_ADD_SOURCE_MEMBERSHIP,
+ linux.IP_BIND_ADDRESS_NO_PORT,
+ linux.IP_BLOCK_SOURCE,
+ linux.IP_CHECKSUM,
+ linux.IP_DROP_SOURCE_MEMBERSHIP,
+ linux.IP_FREEBIND,
+ linux.IP_IPSEC_POLICY,
+ linux.IP_MINTTL,
+ linux.IP_MSFILTER,
+ linux.IP_MTU_DISCOVER,
+ linux.IP_MULTICAST_ALL,
+ linux.IP_NODEFRAG,
+ linux.IP_OPTIONS,
+ linux.IP_PASSSEC,
+ linux.IP_RECVERR,
+ linux.IP_RECVFRAGSIZE,
+ linux.IP_RECVOPTS,
+ linux.IP_RECVORIGDSTADDR,
+ linux.IP_RECVTTL,
+ linux.IP_RETOPTS,
+ linux.IP_TRANSPARENT,
+ linux.IP_UNBLOCK_SOURCE,
+ linux.IP_UNICAST_IF,
+ linux.IP_XFRM_POLICY,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_MSFILTER,
+ linux.MCAST_UNBLOCK_SOURCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+}
+
+// emitUnimplementedEventTCP emits unimplemented event if name is valid. This
+// function contains names that are common between Get and SetSockOpt when
+// level is SOL_TCP.
+func emitUnimplementedEventTCP(t *kernel.Task, name int) {
+ switch name {
+ case linux.TCP_CONGESTION,
+ linux.TCP_CORK,
+ linux.TCP_FASTOPEN,
+ linux.TCP_FASTOPEN_CONNECT,
+ linux.TCP_FASTOPEN_KEY,
+ linux.TCP_FASTOPEN_NO_COOKIE,
+ linux.TCP_QUEUE_SEQ,
+ linux.TCP_REPAIR,
+ linux.TCP_REPAIR_QUEUE,
+ linux.TCP_REPAIR_WINDOW,
+ linux.TCP_SAVED_SYN,
+ linux.TCP_SAVE_SYN,
+ linux.TCP_THIN_DUPACK,
+ linux.TCP_THIN_LINEAR_TIMEOUTS,
+ linux.TCP_TIMESTAMP,
+ linux.TCP_ULP:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// emitUnimplementedEventIPv6 emits unimplemented event if name is valid. It
+// contains names that are common between Get and SetSockOpt when level is
+// SOL_IPV6.
+func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
+ switch name {
+ case linux.IPV6_2292DSTOPTS,
+ linux.IPV6_2292HOPLIMIT,
+ linux.IPV6_2292HOPOPTS,
+ linux.IPV6_2292PKTINFO,
+ linux.IPV6_2292PKTOPTIONS,
+ linux.IPV6_2292RTHDR,
+ linux.IPV6_ADDR_PREFERENCES,
+ linux.IPV6_AUTOFLOWLABEL,
+ linux.IPV6_DONTFRAG,
+ linux.IPV6_DSTOPTS,
+ linux.IPV6_FLOWINFO,
+ linux.IPV6_FLOWINFO_SEND,
+ linux.IPV6_FLOWLABEL_MGR,
+ linux.IPV6_FREEBIND,
+ linux.IPV6_HOPOPTS,
+ linux.IPV6_MINHOPCOUNT,
+ linux.IPV6_MTU,
+ linux.IPV6_MTU_DISCOVER,
+ linux.IPV6_MULTICAST_ALL,
+ linux.IPV6_MULTICAST_HOPS,
+ linux.IPV6_MULTICAST_IF,
+ linux.IPV6_MULTICAST_LOOP,
+ linux.IPV6_RECVDSTOPTS,
+ linux.IPV6_RECVERR,
+ linux.IPV6_RECVFRAGSIZE,
+ linux.IPV6_RECVHOPLIMIT,
+ linux.IPV6_RECVHOPOPTS,
+ linux.IPV6_RECVORIGDSTADDR,
+ linux.IPV6_RECVPATHMTU,
+ linux.IPV6_RECVPKTINFO,
+ linux.IPV6_RECVRTHDR,
+ linux.IPV6_RTHDR,
+ linux.IPV6_RTHDRDSTOPTS,
+ linux.IPV6_TCLASS,
+ linux.IPV6_TRANSPARENT,
+ linux.IPV6_UNICAST_HOPS,
+ linux.IPV6_UNICAST_IF,
+ linux.MCAST_MSFILTER,
+ linux.IPV6_ADDRFORM:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// emitUnimplementedEventIP emits unimplemented event if name is valid. It
+// contains names that are common between Get and SetSockOpt when level is
+// SOL_IP.
+func emitUnimplementedEventIP(t *kernel.Task, name int) {
+ switch name {
+ case linux.IP_TOS,
+ linux.IP_TTL,
+ linux.IP_HDRINCL,
+ linux.IP_OPTIONS,
+ linux.IP_ROUTER_ALERT,
+ linux.IP_RECVOPTS,
+ linux.IP_RETOPTS,
+ linux.IP_PKTINFO,
+ linux.IP_PKTOPTIONS,
+ linux.IP_MTU_DISCOVER,
+ linux.IP_RECVERR,
+ linux.IP_RECVTTL,
+ linux.IP_RECVTOS,
+ linux.IP_MTU,
+ linux.IP_FREEBIND,
+ linux.IP_IPSEC_POLICY,
+ linux.IP_XFRM_POLICY,
+ linux.IP_PASSSEC,
+ linux.IP_TRANSPARENT,
+ linux.IP_ORIGDSTADDR,
+ linux.IP_MINTTL,
+ linux.IP_NODEFRAG,
+ linux.IP_CHECKSUM,
+ linux.IP_BIND_ADDRESS_NO_PORT,
+ linux.IP_RECVFRAGSIZE,
+ linux.IP_MULTICAST_IF,
+ linux.IP_MULTICAST_TTL,
+ linux.IP_MULTICAST_LOOP,
+ linux.IP_ADD_MEMBERSHIP,
+ linux.IP_DROP_MEMBERSHIP,
+ linux.IP_UNBLOCK_SOURCE,
+ linux.IP_BLOCK_SOURCE,
+ linux.IP_ADD_SOURCE_MEMBERSHIP,
+ linux.IP_DROP_SOURCE_MEMBERSHIP,
+ linux.IP_MSFILTER,
+ linux.MCAST_JOIN_GROUP,
+ linux.MCAST_BLOCK_SOURCE,
+ linux.MCAST_UNBLOCK_SOURCE,
+ linux.MCAST_LEAVE_GROUP,
+ linux.MCAST_JOIN_SOURCE_GROUP,
+ linux.MCAST_LEAVE_SOURCE_GROUP,
+ linux.MCAST_MSFILTER,
+ linux.IP_MULTICAST_ALL,
+ linux.IP_UNICAST_IF:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ l := len([]byte(addr.Addr))
+ for i := 0; i < l; i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract and empty paths always return
+ // the full exact length.
+ if l == 0 || out.Path[0] == 0 || l == len(out.Path) {
+ return &out, uint32(2 + l)
+ }
+ return &out, uint32(3 + l)
+
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = htons(addr.Port)
+ return &out, uint32(sockAddrInetSize)
+
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == header.IPv4AddressSize {
+ // Copy address in v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return &out, uint32(sockAddrInet6Size)
+
+ case linux.AF_PACKET:
+ // TODO(b/129292371): Return protocol too.
+ var out linux.SockAddrLink
+ out.Family = linux.AF_PACKET
+ out.InterfaceIndex = int32(addr.NIC)
+ out.HardwareAddrLen = header.EthernetAddressSize
+ copy(out.HardwareAddr[:], addr.Addr)
+ return &out, uint32(sockAddrLinkSize)
+
+ default:
+ return nil, 0
+ }
+}
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.Endpoint.GetLocalAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := ConvertAddress(s.family, addr)
+ return a, l, nil
+}
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.Endpoint.GetRemoteAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := ConvertAddress(s.family, addr)
+ return a, l, nil
+}
+
+// coalescingRead is the fast path for non-blocking, non-peek, stream-based
+// case. It coalesces as many packets as possible before returning to the
+// caller.
+//
+// Precondition: s.readMu must be locked.
+func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
+ var err *syserr.Error
+ var copied int
+
+ // Copy as many views as possible into the user-provided buffer.
+ for {
+ // Always do at least one fetchReadView, even if the number of bytes to
+ // read is 0.
+ err = s.fetchReadView()
+ if err != nil {
+ break
+ }
+ if dst.NumBytes() == 0 {
+ break
+ }
+
+ var n int
+ var e error
+ if discard {
+ n = len(s.readView)
+ if int64(n) > dst.NumBytes() {
+ n = int(dst.NumBytes())
+ }
+ } else {
+ n, e = dst.CopyOut(ctx, s.readView)
+ // Set the control message, even if 0 bytes were read.
+ if e == nil {
+ s.updateTimestamp()
+ }
+ }
+ copied += n
+ s.readView.TrimFront(n)
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
+ dst = dst.DropFirst(n)
+ if e != nil {
+ err = syserr.FromError(e)
+ break
+ }
+ }
+
+ // If we managed to copy something, we must deliver it.
+ if copied > 0 {
+ s.Endpoint.ModerateRecvBuf(copied)
+ return copied, nil
+ }
+
+ return 0, err
+}
+
+func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) {
+ if !s.sockOptInq {
+ return
+ }
+ rcvBufUsed, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
+ return
+ }
+ cmsg.IP.HasInq = true
+ cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed)
+}
+
+// nonBlockingRead issues a non-blocking read.
+//
+// TODO(b/78348848): Support timestamps for stream sockets.
+func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) {
+ isPacket := s.isPacketBased()
+
+ // Fast path for regular reads from stream (e.g., TCP) endpoints. Note
+ // that senderRequested is ignored for stream sockets.
+ if !peek && !isPacket {
+ // TCP sockets discard the data if MSG_TRUNC is set.
+ //
+ // This behavior is documented in man 7 tcp:
+ // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
+ // argument of recv(2) (and recvmsg(2)). This flag causes the received
+ // bytes of data to be discarded, rather than passed back in a
+ // caller-supplied buffer.
+ s.readMu.Lock()
+ n, err := s.coalescingRead(ctx, dst, trunc)
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ s.readMu.Unlock()
+ return n, 0, nil, 0, cmsg, err
+ }
+
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+
+ if err := s.fetchReadView(); err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if !isPacket && peek && trunc {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read.
+ rql, err := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if err != nil {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.TranslateNetstackError(err)
+ }
+ available := len(s.readView) + int(rql)
+ bufLen := int(dst.NumBytes())
+ if available < bufLen {
+ return available, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+ return bufLen, 0, nil, 0, socket.ControlMessages{}, nil
+ }
+
+ n, err := dst.CopyOut(ctx, s.readView)
+ // Set the control message, even if 0 bytes were read.
+ if err == nil {
+ s.updateTimestamp()
+ }
+ var addr linux.SockAddr
+ var addrLen uint32
+ if isPacket && senderRequested {
+ addr, addrLen = ConvertAddress(s.family, s.sender)
+ }
+
+ if peek {
+ if l := len(s.readView); trunc && l > n {
+ // isPacket must be true.
+ return l, linux.MSG_TRUNC, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ }
+
+ if isPacket || err != nil {
+ return n, 0, addr, addrLen, s.controlMessages(), syserr.FromError(err)
+ }
+
+ // We need to peek beyond the first message.
+ dst = dst.DropFirst(n)
+ num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
+ n, _, err := s.Endpoint.Peek(dsts)
+ // TODO(b/78348848): Handle peek timestamp.
+ if err != nil {
+ return int64(n), syserr.TranslateNetstackError(err).ToError()
+ }
+ return int64(n), nil
+ }})
+ n += int(num)
+ if err == syserror.ErrWouldBlock && n > 0 {
+ // We got some data, so no need to return an error.
+ err = nil
+ }
+ return n, 0, nil, 0, s.controlMessages(), syserr.FromError(err)
+ }
+
+ var msgLen int
+ if isPacket {
+ msgLen = len(s.readView)
+ s.readView = nil
+ } else {
+ msgLen = int(n)
+ s.readView.TrimFront(int(n))
+ }
+
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
+ var flags int
+ if msgLen > int(n) {
+ flags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = msgLen
+ }
+
+ cmsg := s.controlMessages()
+ s.fillCmsgInq(&cmsg)
+ return n, flags, addr, addrLen, cmsg, syserr.FromError(err)
+}
+
+func (s *socketOpsCommon) controlMessages() socket.ControlMessages {
+ return socket.ControlMessages{
+ IP: tcpip.ControlMessages{
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
+ },
+ }
+}
+
+// updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after
+// successfully writing packet data out to userspace.
+//
+// Precondition: s.readMu must be locked.
+func (s *socketOpsCommon) updateTimestamp() {
+ // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled.
+ if !s.sockOptTimestamp {
+ s.timestampValid = true
+ s.timestampNS = s.readCM.Timestamp
+ }
+}
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+ peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
+ if senderRequested && !s.isPacketBased() {
+ // Stream sockets ignore the sender address.
+ senderRequested = false
+ }
+ n, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+
+ if s.isPacketBased() && err == syserr.ErrClosedForReceive && flags&linux.MSG_DONTWAIT != 0 {
+ // In this situation we should return EAGAIN.
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+
+ if err != nil && (err != syserr.ErrWouldBlock || dontWait) {
+ // Read failed and we should not retry.
+ return 0, 0, nil, 0, socket.ControlMessages{}, err
+ }
+
+ if err == nil && (dontWait || !waitAll || s.isPacketBased() || int64(n) >= dst.NumBytes()) {
+ // We got all the data we need.
+ return
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst(n)
+
+ // We'll have to block. Register for notifications and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ var rn int
+ rn, msgFlags, senderAddr, senderAddrLen, controlMessages, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ n += rn
+ if err != nil && err != syserr.ErrWouldBlock {
+ // Always stop on errors other than would block as we generally
+ // won't be able to get any more data. Eat the error if we got
+ // any data.
+ if n > 0 {
+ err = nil
+ }
+ return
+ }
+ if err == nil && (s.isPacketBased() || !waitAll || int64(rn) >= dst.NumBytes()) {
+ // We got all the data we need.
+ return
+ }
+ dst = dst.DropFirst(rn)
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if n > 0 {
+ return n, msgFlags, senderAddr, senderAddrLen, controlMessages, nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ // Reject Unix control messages.
+ if !controlMessages.Unix.Empty() {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ var addr *tcpip.FullAddress
+ if len(to) > 0 {
+ addrBuf, family, err := AddressAndFamily(to)
+ if err != nil {
+ return 0, err
+ }
+ if err := s.checkFamily(family, false /* exact */); err != nil {
+ return 0, err
+ }
+ addrBuf = s.mapFamily(addrBuf, family)
+
+ addr = &addrBuf
+ }
+
+ opts := tcpip.WriteOptions{
+ To: addr,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }
+
+ v := &ioSequencePayload{t, src}
+ n, resCh, err := s.Endpoint.Write(v, opts)
+ if resCh != nil {
+ if err := t.Block(resCh); err != nil {
+ return 0, syserr.FromError(err)
+ }
+ n, _, err = s.Endpoint.Write(v, opts)
+ }
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ if err == nil && (n >= v.src.NumBytes() || dontWait) {
+ // Complete write.
+ return int(n), nil
+ }
+ if err != nil && (err != tcpip.ErrWouldBlock || dontWait) {
+ return int(n), syserr.TranslateNetstackError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ v.DropFirst(int(n))
+ total := n
+ for {
+ n, _, err = s.Endpoint.Write(v, opts)
+ v.DropFirst(int(n))
+ total += n
+
+ if err != nil && err != tcpip.ErrWouldBlock && total == 0 {
+ return 0, syserr.TranslateNetstackError(err)
+ }
+
+ if err == nil && v.src.NumBytes() == 0 || err != nil && err != tcpip.ErrWouldBlock {
+ return int(total), nil
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ return int(total), syserr.ErrTryAgain
+ }
+ // handleIOError will consume errors from t.Block if needed.
+ return int(total), syserr.FromError(err)
+ }
+ }
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return s.socketOpsCommon.ioctl(ctx, io, args)
+}
+
+func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint
+ // sockets.
+ // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP.
+ switch args[1].Int() {
+ case syscall.SIOCGSTAMP:
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if !s.timestampValid {
+ return 0, syserror.ENOENT
+ }
+
+ tv := linux.NsecToTimeval(s.timestampNS)
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCINQ:
+ v, terr := s.Endpoint.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ // Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
+ v += len(s.readView)
+ s.readMu.Unlock()
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to userspace.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ }
+
+ return Ioctl(ctx, s.Endpoint, io, args)
+}
+
+// Ioctl performs a socket ioctl.
+func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch arg := int(args[1].Int()); arg {
+ case syscall.SIOCGIFFLAGS,
+ syscall.SIOCGIFADDR,
+ syscall.SIOCGIFBRDADDR,
+ syscall.SIOCGIFDSTADDR,
+ syscall.SIOCGIFHWADDR,
+ syscall.SIOCGIFINDEX,
+ syscall.SIOCGIFMAP,
+ syscall.SIOCGIFMETRIC,
+ syscall.SIOCGIFMTU,
+ syscall.SIOCGIFNAME,
+ syscall.SIOCGIFNETMASK,
+ syscall.SIOCGIFTXQLEN:
+
+ var ifr linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil {
+ return 0, err.ToError()
+ }
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case syscall.SIOCGIFCONF:
+ // Return a list of interface addresses or the buffer size
+ // necessary to hold the list.
+ var ifc linux.IFConf
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ if err := ifconfIoctl(ctx, io, &ifc); err != nil {
+ return 0, err
+ }
+
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+
+ return 0, err
+
+ case linux.TIOCINQ:
+ v, terr := ep.GetSockOptInt(tcpip.ReceiveQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+ // Copy result to userspace.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.TIOCOUTQ:
+ v, terr := ep.GetSockOptInt(tcpip.SendQueueSizeOption)
+ if terr != nil {
+ return 0, syserr.TranslateNetstackError(terr).ToError()
+ }
+
+ if v > math.MaxInt32 {
+ v = math.MaxInt32
+ }
+
+ // Copy result to userspace.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG:
+ unimpl.EmitUnimplementedEvent(ctx)
+ }
+
+ return 0, syserror.ENOTTY
+}
+
+// interfaceIoctl implements interface requests.
+func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+ var (
+ iface inet.Interface
+ index int32
+ found bool
+ )
+
+ // Find the relevant device.
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+
+ // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to
+ // identify a device.
+ if arg == syscall.SIOCGIFNAME {
+ // Gets the name of the interface given the interface index
+ // stored in ifr_ifindex.
+ index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4]))
+ if iface, ok := stack.Interfaces()[index]; ok {
+ ifr.SetName(iface.Name)
+ return nil
+ }
+ return syserr.ErrNoDevice
+ }
+
+ // Find the relevant device.
+ for index, iface = range stack.Interfaces() {
+ if iface.Name == ifr.Name() {
+ found = true
+ break
+ }
+ }
+ if !found {
+ return syserr.ErrNoDevice
+ }
+
+ switch arg {
+ case syscall.SIOCGIFINDEX:
+ // Copy out the index to the data.
+ usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
+
+ case syscall.SIOCGIFHWADDR:
+ // Copy the hardware address out.
+ ifr.Data[0] = 6 // IEEE802.2 arp type.
+ ifr.Data[1] = 0
+ n := copy(ifr.Data[2:], iface.Addr)
+ for i := 2 + n; i < len(ifr.Data); i++ {
+ ifr.Data[i] = 0 // Clear padding.
+ }
+ usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n))
+
+ case syscall.SIOCGIFFLAGS:
+ f, err := interfaceStatusFlags(stack, iface.Name)
+ if err != nil {
+ return err
+ }
+ // Drop the flags that don't fit in the size that we need to return. This
+ // matches Linux behavior.
+ usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f))
+
+ case syscall.SIOCGIFADDR:
+ // Copy the IPv4 address out.
+ for _, addr := range stack.InterfaceAddrs()[index] {
+ // This ioctl is only compatible with AF_INET addresses.
+ if addr.Family != linux.AF_INET {
+ continue
+ }
+ copy(ifr.Data[4:8], addr.Addr)
+ break
+ }
+
+ case syscall.SIOCGIFMETRIC:
+ // Gets the metric of the device. As per netdevice(7), this
+ // always just sets ifr_metric to 0.
+ usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
+
+ case syscall.SIOCGIFMTU:
+ // Gets the MTU of the device.
+ usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU)
+
+ case syscall.SIOCGIFMAP:
+ // Gets the hardware parameters of the device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFTXQLEN:
+ // Gets the transmit queue length of the device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFDSTADDR:
+ // Gets the destination address of a point-to-point device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFBRDADDR:
+ // Gets the broadcast address of a device.
+ // TODO(gvisor.dev/issue/505): Implement.
+
+ case syscall.SIOCGIFNETMASK:
+ // Gets the network mask of a device.
+ for _, addr := range stack.InterfaceAddrs()[index] {
+ // This ioctl is only compatible with AF_INET addresses.
+ if addr.Family != linux.AF_INET {
+ continue
+ }
+ // Populate ifr.ifr_netmask (type sockaddr).
+ usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(linux.AF_INET))
+ usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ var mask uint32 = 0xffffffff << (32 - addr.PrefixLen)
+ // Netmask is expected to be returned as a big endian
+ // value.
+ binary.BigEndian.PutUint32(ifr.Data[4:8], mask)
+ break
+ }
+
+ default:
+ // Not a valid call.
+ return syserr.ErrInvalidArgument
+ }
+
+ return nil
+}
+
+// ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
+func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+ // If Ptr is NULL, return the necessary buffer size via Len.
+ // Otherwise, write up to Len bytes starting at Ptr containing ifreq
+ // structs.
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ return syserr.ErrNoDevice.ToError()
+ }
+
+ if ifc.Ptr == 0 {
+ ifc.Len = int32(len(stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ return nil
+ }
+
+ max := ifc.Len
+ ifc.Len = 0
+ for key, ifaceAddrs := range stack.InterfaceAddrs() {
+ iface := stack.Interfaces()[key]
+ for _, ifaceAddr := range ifaceAddrs {
+ // Don't write past the end of the buffer.
+ if ifc.Len+int32(linux.SizeOfIFReq) > max {
+ break
+ }
+ if ifaceAddr.Family != linux.AF_INET {
+ continue
+ }
+
+ // Populate ifr.ifr_addr.
+ ifr := linux.IFReq{}
+ ifr.SetName(iface.Name)
+ usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
+ usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ copy(ifr.Data[4:8], ifaceAddr.Addr[:4])
+
+ // Copy the ifr to userspace.
+ dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
+ ifc.Len += int32(linux.SizeOfIFReq)
+ if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return err
+ }
+ }
+ }
+ return nil
+}
+
+// interfaceStatusFlags returns status flags for an interface in the stack.
+// Flag values and meanings are described in greater detail in netdevice(7) in
+// the SIOCGIFFLAGS section.
+func interfaceStatusFlags(stack inet.Stack, name string) (uint32, *syserr.Error) {
+ // We should only ever be passed a netstack.Stack.
+ epstack, ok := stack.(*Stack)
+ if !ok {
+ return 0, errStackType
+ }
+
+ // Find the NIC corresponding to this interface.
+ for _, info := range epstack.Stack.NICInfo() {
+ if info.Name == name {
+ return nicStateFlagsToLinux(info.Flags), nil
+ }
+ }
+ return 0, syserr.ErrNoDevice
+}
+
+func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 {
+ var rv uint32
+ if f.Up {
+ rv |= linux.IFF_UP | linux.IFF_LOWER_UP
+ }
+ if f.Running {
+ rv |= linux.IFF_RUNNING
+ }
+ if f.Promiscuous {
+ rv |= linux.IFF_PROMISC
+ }
+ if f.Loopback {
+ rv |= linux.IFF_LOOPBACK
+ }
+ return rv
+}
+
+// State implements socket.Socket.State. State translates the internal state
+// returned by netstack to values defined by Linux.
+func (s *socketOpsCommon) State() uint32 {
+ if s.family != linux.AF_INET && s.family != linux.AF_INET6 {
+ // States not implemented for this socket's family.
+ return 0
+ }
+
+ switch {
+ case s.skType == linux.SOCK_STREAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_TCP:
+ // TCP socket.
+ switch tcp.EndpointState(s.Endpoint.State()) {
+ case tcp.StateEstablished:
+ return linux.TCP_ESTABLISHED
+ case tcp.StateSynSent:
+ return linux.TCP_SYN_SENT
+ case tcp.StateSynRecv:
+ return linux.TCP_SYN_RECV
+ case tcp.StateFinWait1:
+ return linux.TCP_FIN_WAIT1
+ case tcp.StateFinWait2:
+ return linux.TCP_FIN_WAIT2
+ case tcp.StateTimeWait:
+ return linux.TCP_TIME_WAIT
+ case tcp.StateClose, tcp.StateInitial, tcp.StateBound, tcp.StateConnecting, tcp.StateError:
+ return linux.TCP_CLOSE
+ case tcp.StateCloseWait:
+ return linux.TCP_CLOSE_WAIT
+ case tcp.StateLastAck:
+ return linux.TCP_LAST_ACK
+ case tcp.StateListen:
+ return linux.TCP_LISTEN
+ case tcp.StateClosing:
+ return linux.TCP_CLOSING
+ default:
+ // Internal or unknown state.
+ return 0
+ }
+ case s.skType == linux.SOCK_DGRAM && s.protocol == 0 || s.protocol == syscall.IPPROTO_UDP:
+ // UDP socket.
+ switch udp.EndpointState(s.Endpoint.State()) {
+ case udp.StateInitial, udp.StateBound, udp.StateClosed:
+ return linux.TCP_CLOSE
+ case udp.StateConnected:
+ return linux.TCP_ESTABLISHED
+ default:
+ return 0
+ }
+ case s.skType == linux.SOCK_DGRAM && s.protocol == syscall.IPPROTO_ICMP || s.protocol == syscall.IPPROTO_ICMPV6:
+ // TODO(b/112063468): Export states for ICMP sockets.
+ case s.skType == linux.SOCK_RAW:
+ // TODO(b/112063468): Export states for raw sockets.
+ default:
+ // Unknown transport protocol, how did we make this socket?
+ log.Warningf("Unknown transport protocol for an existing socket: family=%v, type=%v, protocol=%v, internal type %v", s.family, s.skType, s.protocol, reflect.TypeOf(s.Endpoint).Elem())
+ return 0
+ }
+
+ return 0
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ return s.family, s.skType, s.protocol
+}
+
+// LINT.ThenChange(./netstack_vfs2.go)
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
new file mode 100644
index 000000000..d65a89316
--- /dev/null
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -0,0 +1,330 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/amutex"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 encapsulates all the state needed to represent a network stack
+// endpoint in the kernel context.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&SocketVFS2{})
+
+// NewVFS2 creates a new endpoint socket.
+func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) {
+ if skType == linux.SOCK_STREAM {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ }
+
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ s := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ Queue: queue,
+ family: family,
+ Endpoint: endpoint,
+ skType: skType,
+ protocol: protocol,
+ },
+ }
+ s.LockFD.Init(&vfs.FileLocks{})
+ vfsfd := &s.vfsfd
+ if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ if err == syserr.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, err.ToError()
+ }
+ return int64(n), nil
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ f := &ioSequencePayload{ctx: ctx, src: src}
+ n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ if err == tcpip.ErrWouldBlock {
+ return 0, syserror.ErrWouldBlock
+ }
+
+ if resCh != nil {
+ if err := amutex.Block(ctx, resCh); err != nil {
+ return 0, err
+ }
+ n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{})
+ }
+
+ if err != nil {
+ return 0, syserr.TranslateNetstackError(err).ToError()
+ }
+
+ if int64(n) < src.NumBytes() {
+ return int64(n), syserror.ErrWouldBlock
+ }
+
+ return int64(n), nil
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, wq, terr := s.Endpoint.Accept()
+ if terr != nil {
+ if terr != tcpip.ErrWouldBlock || !blocking {
+ return 0, nil, 0, syserr.TranslateNetstackError(terr)
+ }
+
+ var err *syserr.Error
+ ep, wq, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := NewVFS2(t, s.family, s.skType, s.protocol, wq, ep)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil {
+ return 0, nil, 0, syserr.FromError(err)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer and write it to peer slice.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+
+ t.Kernel().RecordSocketVFS2(ns)
+
+ return fd, addr, addrLen, syserr.FromError(e)
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return s.socketOpsCommon.ioctl(ctx, uio, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketVFS2 rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptTimestamp {
+ val = 1
+ }
+ return val, nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ val := int32(0)
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ if s.sockOptInq {
+ val = 1
+ }
+ return val, nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
+ if err != nil {
+ return nil, err
+ }
+ return info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return entries, nil
+
+ }
+ }
+
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// tcpip.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is
+ // implemented specifically for netstack.SocketVFS2 rather than
+ // commonEndpoint. commonEndpoint should be extended to support socket
+ // options where the implementation is not shared, as unix sockets need
+ // their own support for SO_TIMESTAMP.
+ if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+ if level == linux.SOL_TCP && name == linux.TCP_INQ {
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+ s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0
+ return nil
+ }
+
+ if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
+ switch name {
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+ }
+ }
+
+ return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
new file mode 100644
index 000000000..ead3b2b79
--- /dev/null
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -0,0 +1,199 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// LINT.IfChange
+
+// provider is an inet socket provider.
+type provider struct {
+ family int
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// getTransportProtocol figures out transport protocol. Currently only TCP,
+// UDP, and ICMP are supported. The bool return value is true when this socket
+// is associated with a transport protocol. This is only false for SOCK_RAW,
+// IPPROTO_IP sockets.
+func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol int) (tcpip.TransportProtocolNumber, bool, *syserr.Error) {
+ switch stype {
+ case linux.SOCK_STREAM:
+ if protocol != 0 && protocol != syscall.IPPROTO_TCP {
+ return 0, true, syserr.ErrInvalidArgument
+ }
+ return tcp.ProtocolNumber, true, nil
+
+ case linux.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ return udp.ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
+ }
+
+ case linux.SOCK_RAW:
+ // Raw sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(ctx)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return 0, true, syserr.ErrNotPermitted
+ }
+
+ switch protocol {
+ case syscall.IPPROTO_ICMP:
+ return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
+ case syscall.IPPROTO_UDP:
+ return header.UDPProtocolNumber, true, nil
+ case syscall.IPPROTO_TCP:
+ return header.TCPProtocolNumber, true, nil
+ // IPPROTO_RAW signifies that the raw socket isn't assigned to
+ // a transport protocol. Users will be able to write packets'
+ // IP headers and won't receive anything.
+ case syscall.IPPROTO_RAW:
+ return tcpip.TransportProtocolNumber(0), false, nil
+ }
+ }
+ return 0, true, syserr.ErrProtocolNotSupported
+}
+
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
+func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Fail right away if we don't have a stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ // Don't propagate an error here. Instead, allow the socket
+ // code to continue searching for another provider.
+ return nil, nil
+ }
+ eps, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocket(t, eps, stype, protocol)
+ }
+
+ // Figure out the transport protocol.
+ transProto, associated, err := getTransportProtocol(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
+ wq := &waiter.Queue{}
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
+ }
+ if e != nil {
+ return nil, syserr.TranslateNetstackError(e)
+ }
+
+ return New(t, p.family, stype, int(transProto), wq, ep)
+}
+
+func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return New(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
+// LINT.ThenChange(./provider_vfs2.go)
+
+// Pair just returns nil sockets (not supported).
+func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ return nil, nil, nil
+}
+
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
+func init() {
+ // Providers backed by netstack.
+ p := []provider{
+ {
+ family: linux.AF_INET,
+ netProto: ipv4.ProtocolNumber,
+ },
+ {
+ family: linux.AF_INET6,
+ netProto: ipv6.ProtocolNumber,
+ },
+ {
+ family: linux.AF_PACKET,
+ },
+ }
+
+ for i := range p {
+ socket.RegisterProvider(p[i].family, &p[i])
+ }
+}
diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go
new file mode 100644
index 000000000..2a01143f6
--- /dev/null
+++ b/pkg/sentry/socket/netstack/provider_vfs2.go
@@ -0,0 +1,141 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// providerVFS2 is an inet socket provider.
+type providerVFS2 struct {
+ family int
+ netProto tcpip.NetworkProtocolNumber
+}
+
+// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET
+// family.
+func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Fail right away if we don't have a stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ // Don't propagate an error here. Instead, allow the socket
+ // code to continue searching for another provider.
+ return nil, nil
+ }
+ eps, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Packet sockets are handled separately, since they are neither INET
+ // nor INET6 specific.
+ if p.family == linux.AF_PACKET {
+ return packetSocketVFS2(t, eps, stype, protocol)
+ }
+
+ // Figure out the transport protocol.
+ transProto, associated, err := getTransportProtocol(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the endpoint.
+ var ep tcpip.Endpoint
+ var e *tcpip.Error
+ wq := &waiter.Queue{}
+ if stype == linux.SOCK_RAW {
+ ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
+ } else {
+ ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
+ }
+ if e != nil {
+ return nil, syserr.TranslateNetstackError(e)
+ }
+
+ return NewVFS2(t, p.family, stype, int(transProto), wq, ep)
+}
+
+func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Packet sockets require CAP_NET_RAW.
+ creds := auth.CredentialsFromContext(t)
+ if !creds.HasCapability(linux.CAP_NET_RAW) {
+ return nil, syserr.ErrNotPermitted
+ }
+
+ // "cooked" packets don't contain link layer information.
+ var cooked bool
+ switch stype {
+ case linux.SOCK_DGRAM:
+ cooked = true
+ case linux.SOCK_RAW:
+ cooked = false
+ default:
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // protocol is passed in network byte order, but netstack wants it in
+ // host order.
+ netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol)))
+
+ wq := &waiter.Queue{}
+ ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return NewVFS2(t, linux.AF_PACKET, stype, protocol, wq, ep)
+}
+
+// Pair just returns nil sockets (not supported).
+func (*providerVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ return nil, nil, nil
+}
+
+// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET.
+func init() {
+ // Providers backed by netstack.
+ p := []providerVFS2{
+ {
+ family: linux.AF_INET,
+ netProto: ipv4.ProtocolNumber,
+ },
+ {
+ family: linux.AF_INET6,
+ netProto: ipv6.ProtocolNumber,
+ },
+ {
+ family: linux.AF_PACKET,
+ },
+ }
+
+ for i := range p {
+ socket.RegisterProviderVFS2(p[i].family, &p[i])
+ }
+}
diff --git a/pkg/sentry/socket/netstack/save_restore.go b/pkg/sentry/socket/netstack/save_restore.go
new file mode 100644
index 000000000..c7aaf722a
--- /dev/null
+++ b/pkg/sentry/socket/netstack/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// afterLoad is invoked by stateify.
+func (s *Stack) afterLoad() {
+ s.Stack = stack.StackFromEnv // FIXME(b/36201077)
+ if s.Stack == nil {
+ panic("can't restore without netstack/tcpip/stack.Stack")
+ }
+}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
new file mode 100644
index 000000000..548442b96
--- /dev/null
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -0,0 +1,386 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netstack
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sentry/inet"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
+)
+
+// Stack implements inet.Stack for netstack/tcpip/stack.Stack.
+//
+// +stateify savable
+type Stack struct {
+ Stack *stack.Stack `state:"manual"`
+}
+
+// SupportsIPv6 implements Stack.SupportsIPv6.
+func (s *Stack) SupportsIPv6() bool {
+ return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ is := make(map[int32]inet.Interface)
+ for id, ni := range s.Stack.NICInfo() {
+ var devType uint16
+ if ni.Flags.Loopback {
+ devType = linux.ARPHRD_LOOPBACK
+ }
+ is[int32(id)] = inet.Interface{
+ Name: ni.Name,
+ Addr: []byte(ni.LinkAddress),
+ Flags: uint32(nicStateFlagsToLinux(ni.Flags)),
+ DeviceType: devType,
+ MTU: ni.MTU,
+ }
+ }
+ return is
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ nicAddrs := make(map[int32][]inet.InterfaceAddr)
+ for id, ni := range s.Stack.NICInfo() {
+ var addrs []inet.InterfaceAddr
+ for _, a := range ni.ProtocolAddresses {
+ var family uint8
+ switch a.Protocol {
+ case ipv4.ProtocolNumber:
+ family = linux.AF_INET
+ case ipv6.ProtocolNumber:
+ family = linux.AF_INET6
+ default:
+ log.Warningf("Unknown network protocol in %+v", a)
+ continue
+ }
+
+ addrs = append(addrs, inet.InterfaceAddr{
+ Family: family,
+ PrefixLen: uint8(a.AddressWithPrefix.PrefixLen),
+ Addr: []byte(a.AddressWithPrefix.Address),
+ // TODO(b/68878065): Other fields.
+ })
+ }
+ nicAddrs[int32(id)] = addrs
+ }
+ return nicAddrs
+}
+
+// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
+func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
+ var (
+ protocol tcpip.NetworkProtocolNumber
+ address tcpip.Address
+ )
+ switch addr.Family {
+ case linux.AF_INET:
+ if len(addr.Addr) < header.IPv4AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv4AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv4.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv4AddressSize])
+
+ case linux.AF_INET6:
+ if len(addr.Addr) < header.IPv6AddressSize {
+ return syserror.EINVAL
+ }
+ if addr.PrefixLen > header.IPv6AddressSize*8 {
+ return syserror.EINVAL
+ }
+ protocol = ipv6.ProtocolNumber
+ address = tcpip.Address(addr.Addr[:header.IPv6AddressSize])
+
+ default:
+ return syserror.ENOTSUP
+ }
+
+ protocolAddress := tcpip.ProtocolAddress{
+ Protocol: protocol,
+ AddressWithPrefix: tcpip.AddressWithPrefix{
+ Address: address,
+ PrefixLen: int(addr.PrefixLen),
+ },
+ }
+
+ // Attach address to interface.
+ if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil {
+ return syserr.TranslateNetstackError(err).ToError()
+ }
+
+ // Add route for local network.
+ s.Stack.AddRoute(tcpip.Route{
+ Destination: protocolAddress.AddressWithPrefix.Subnet(),
+ Gateway: "", // No gateway for local network.
+ NIC: tcpip.NICID(idx),
+ })
+ return nil
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ var rs tcp.ReceiveBufferSizeOption
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs)
+ return inet.TCPBufferSize{
+ Min: rs.Min,
+ Default: rs.Default,
+ Max: rs.Max,
+ }, syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ rs := tcp.ReceiveBufferSizeOption{
+ Min: size.Min,
+ Default: size.Default,
+ Max: size.Max,
+ }
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError()
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ var ss tcp.SendBufferSizeOption
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss)
+ return inet.TCPBufferSize{
+ Min: ss.Min,
+ Default: ss.Default,
+ Max: ss.Max,
+ }, syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ ss := tcp.SendBufferSizeOption{
+ Min: size.Min,
+ Default: size.Default,
+ Max: size.Max,
+ }
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError()
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ var sack tcp.SACKEnabled
+ err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack)
+ return bool(sack), syserr.TranslateNetstackError(err).ToError()
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
+}
+
+// Statistics implements inet.Stack.Statistics.
+func (s *Stack) Statistics(stat interface{}, arg string) error {
+ switch stats := stat.(type) {
+ case *inet.StatDev:
+ for _, ni := range s.Stack.NICInfo() {
+ if ni.Name != arg {
+ continue
+ }
+ // TODO(gvisor.dev/issue/2103) Support stubbed stats.
+ *stats = inet.StatDev{
+ // Receive section.
+ ni.Stats.Rx.Bytes.Value(), // bytes.
+ ni.Stats.Rx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // frame.
+ 0, // compressed.
+ 0, // multicast.
+ // Transmit section.
+ ni.Stats.Tx.Bytes.Value(), // bytes.
+ ni.Stats.Tx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // colls.
+ 0, // carrier.
+ 0, // compressed.
+ }
+ break
+ }
+ case *inet.StatSNMPIP:
+ ip := Metrics.IP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
+ *stats = inet.StatSNMPIP{
+ 0, // Ip/Forwarding.
+ 0, // Ip/DefaultTTL.
+ ip.PacketsReceived.Value(), // InReceives.
+ 0, // Ip/InHdrErrors.
+ ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors.
+ 0, // Ip/ForwDatagrams.
+ 0, // Ip/InUnknownProtos.
+ 0, // Ip/InDiscards.
+ ip.PacketsDelivered.Value(), // InDelivers.
+ ip.PacketsSent.Value(), // OutRequests.
+ ip.OutgoingPacketErrors.Value(), // OutDiscards.
+ 0, // Ip/OutNoRoutes.
+ 0, // Support Ip/ReasmTimeout.
+ 0, // Support Ip/ReasmReqds.
+ 0, // Support Ip/ReasmOKs.
+ 0, // Support Ip/ReasmFails.
+ 0, // Support Ip/FragOKs.
+ 0, // Support Ip/FragFails.
+ 0, // Support Ip/FragCreates.
+ }
+ case *inet.StatSNMPICMP:
+ in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
+ out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
+ *stats = inet.StatSNMPICMP{
+ 0, // Icmp/InMsgs.
+ Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
+ 0, // Icmp/InCsumErrors.
+ in.DstUnreachable.Value(), // InDestUnreachs.
+ in.TimeExceeded.Value(), // InTimeExcds.
+ in.ParamProblem.Value(), // InParmProbs.
+ in.SrcQuench.Value(), // InSrcQuenchs.
+ in.Redirect.Value(), // InRedirects.
+ in.Echo.Value(), // InEchos.
+ in.EchoReply.Value(), // InEchoReps.
+ in.Timestamp.Value(), // InTimestamps.
+ in.TimestampReply.Value(), // InTimestampReps.
+ in.InfoRequest.Value(), // InAddrMasks.
+ in.InfoReply.Value(), // InAddrMaskReps.
+ 0, // Icmp/OutMsgs.
+ Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
+ out.DstUnreachable.Value(), // OutDestUnreachs.
+ out.TimeExceeded.Value(), // OutTimeExcds.
+ out.ParamProblem.Value(), // OutParmProbs.
+ out.SrcQuench.Value(), // OutSrcQuenchs.
+ out.Redirect.Value(), // OutRedirects.
+ out.Echo.Value(), // OutEchos.
+ out.EchoReply.Value(), // OutEchoReps.
+ out.Timestamp.Value(), // OutTimestamps.
+ out.TimestampReply.Value(), // OutTimestampReps.
+ out.InfoRequest.Value(), // OutAddrMasks.
+ out.InfoReply.Value(), // OutAddrMaskReps.
+ }
+ case *inet.StatSNMPTCP:
+ tcp := Metrics.TCP
+ // RFC 2012 (updates 1213): SNMPv2-MIB-TCP.
+ *stats = inet.StatSNMPTCP{
+ 1, // RtoAlgorithm.
+ 200, // RtoMin.
+ 120000, // RtoMax.
+ (1<<64 - 1), // MaxConn.
+ tcp.ActiveConnectionOpenings.Value(), // ActiveOpens.
+ tcp.PassiveConnectionOpenings.Value(), // PassiveOpens.
+ tcp.FailedConnectionAttempts.Value(), // AttemptFails.
+ tcp.EstablishedResets.Value(), // EstabResets.
+ tcp.CurrentEstablished.Value(), // CurrEstab.
+ tcp.ValidSegmentsReceived.Value(), // InSegs.
+ tcp.SegmentsSent.Value(), // OutSegs.
+ tcp.Retransmits.Value(), // RetransSegs.
+ tcp.InvalidSegmentsReceived.Value(), // InErrs.
+ tcp.ResetsSent.Value(), // OutRsts.
+ tcp.ChecksumErrors.Value(), // InCsumErrors.
+ }
+ case *inet.StatSNMPUDP:
+ udp := Metrics.UDP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
+ *stats = inet.StatSNMPUDP{
+ udp.PacketsReceived.Value(), // InDatagrams.
+ udp.UnknownPortErrors.Value(), // NoPorts.
+ 0, // Udp/InErrors.
+ udp.PacketsSent.Value(), // OutDatagrams.
+ udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
+ 0, // Udp/SndbufErrors.
+ udp.ChecksumErrors.Value(), // Udp/InCsumErrors.
+ 0, // Udp/IgnoredMulti.
+ }
+ default:
+ return syserr.ErrEndpointOperation.ToError()
+ }
+ return nil
+}
+
+// RouteTable implements inet.Stack.RouteTable.
+func (s *Stack) RouteTable() []inet.Route {
+ var routeTable []inet.Route
+
+ for _, rt := range s.Stack.GetRouteTable() {
+ var family uint8
+ switch len(rt.Destination.ID()) {
+ case header.IPv4AddressSize:
+ family = linux.AF_INET
+ case header.IPv6AddressSize:
+ family = linux.AF_INET6
+ default:
+ log.Warningf("Unknown network protocol in route %+v", rt)
+ continue
+ }
+
+ routeTable = append(routeTable, inet.Route{
+ Family: family,
+ DstLen: uint8(rt.Destination.Prefix()), // The CIDR prefix for the destination.
+
+ // Always return unspecified protocol since we have no notion of
+ // protocol for routes.
+ Protocol: linux.RTPROT_UNSPEC,
+ // Set statically to LINK scope for now.
+ //
+ // TODO(gvisor.dev/issue/595): Set scope for routes.
+ Scope: linux.RT_SCOPE_LINK,
+ Type: linux.RTN_UNICAST,
+
+ DstAddr: []byte(rt.Destination.ID()),
+ OutputInterface: int32(rt.NIC),
+ GatewayAddr: []byte(rt.Gateway),
+ })
+ }
+
+ return routeTable
+}
+
+// IPTables returns the stack's iptables.
+func (s *Stack) IPTables() (*stack.IPTables, error) {
+ return s.Stack.IPTables(), nil
+}
+
+// Resume implements inet.Stack.Resume.
+func (s *Stack) Resume() {
+ s.Stack.Resume()
+}
+
+// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint {
+ return s.Stack.RegisteredEndpoints()
+}
+
+// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint {
+ return s.Stack.CleanupEndpoints()
+}
+
+// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) {
+ s.Stack.RestoreCleanupEndpoints(es)
+}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
new file mode 100644
index 000000000..fcd7f9d7f
--- /dev/null
+++ b/pkg/sentry/socket/socket.go
@@ -0,0 +1,461 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package socket provides the interfaces that need to be provided by socket
+// implementations and providers, as well as per family demultiplexing of socket
+// creation.
+package socket
+
+import (
+ "fmt"
+ "sync/atomic"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/device"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// ControlMessages represents the union of unix control messages and tcpip
+// control messages.
+type ControlMessages struct {
+ Unix transport.ControlMessages
+ IP tcpip.ControlMessages
+}
+
+// Release releases Unix domain socket credentials and rights.
+func (c *ControlMessages) Release() {
+ c.Unix.Release()
+}
+
+// Socket is an interface combining fs.FileOperations and SocketOps,
+// representing a VFS1 socket file.
+type Socket interface {
+ fs.FileOperations
+ SocketOps
+}
+
+// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps,
+// representing a VFS2 socket file.
+type SocketVFS2 interface {
+ vfs.FileDescriptionImpl
+ SocketOps
+}
+
+// SocketOps is the interface containing socket syscalls used by the syscall
+// layer to redirect them to the appropriate implementation.
+//
+// It is implemented by both Socket and SocketVFS2.
+type SocketOps interface {
+ // Connect implements the connect(2) linux syscall.
+ Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
+
+ // Accept implements the accept4(2) linux syscall.
+ // Returns fd, real peer address length and error. Real peer address
+ // length is only set if len(peer) > 0.
+ Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error)
+
+ // Bind implements the bind(2) linux syscall.
+ Bind(t *kernel.Task, sockaddr []byte) *syserr.Error
+
+ // Listen implements the listen(2) linux syscall.
+ Listen(t *kernel.Task, backlog int) *syserr.Error
+
+ // Shutdown implements the shutdown(2) linux syscall.
+ Shutdown(t *kernel.Task, how int) *syserr.Error
+
+ // GetSockOpt implements the getsockopt(2) linux syscall.
+ GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error)
+
+ // SetSockOpt implements the setsockopt(2) linux syscall.
+ SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
+
+ // GetSockName implements the getsockname(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetSockName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
+
+ // GetPeerName implements the getpeername(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetPeerName(t *kernel.Task) (addr linux.SockAddr, addrLen uint32, err *syserr.Error)
+
+ // RecvMsg implements the recvmsg(2) linux syscall.
+ //
+ // senderAddrLen is the address length to be returned to the application,
+ // not necessarily the actual length of the address.
+ //
+ // flags control how RecvMsg should be completed. msgFlags indicate how
+ // the RecvMsg call was completed. Note that control message truncation
+ // may still be required even if the MSG_CTRUNC bit is not set in
+ // msgFlags. In that case, the caller should set MSG_CTRUNC appropriately.
+ //
+ // If err != nil, the recv was not successful.
+ RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages ControlMessages, err *syserr.Error)
+
+ // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
+ // ownership of the ControlMessage on error.
+ //
+ // If n > 0, err will either be nil or an error from t.Block.
+ SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages ControlMessages) (n int, err *syserr.Error)
+
+ // SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetRecvTimeout(nanoseconds int64)
+
+ // RecvTimeout gets the current timeout (in ns) for recv operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ RecvTimeout() int64
+
+ // SetSendTimeout sets the timeout (in ns) for send operations. Zero means
+ // no timeout, and negative means DONTWAIT.
+ SetSendTimeout(nanoseconds int64)
+
+ // SendTimeout gets the current timeout (in ns) for send operations. Zero
+ // means no timeout, and negative means DONTWAIT.
+ SendTimeout() int64
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs. The returned state value is protocol-specific.
+ State() uint32
+
+ // Type returns the family, socket type and protocol of the socket.
+ Type() (family int, skType linux.SockType, protocol int)
+}
+
+// Provider is the interface implemented by providers of sockets for specific
+// address families (e.g., AF_INET).
+type Provider interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
+}
+
+// families holds a map of all known address families and their providers.
+var families = make(map[int][]Provider)
+
+// RegisterProvider registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProvider(family int, provider Provider) {
+ families[family] = append(families[family], provider)
+}
+
+// New creates a new socket with the given family, type and protocol.
+func New(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ for _, p := range families[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocket(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// Pair creates a new connected socket pair with the given family, type and
+// protocol.
+func Pair(t *kernel.Task, family int, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ providers, ok := families[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocket(s1)
+ k.RecordSocket(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
+// NewDirent returns a sockfs fs.Dirent that resides on device d.
+func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
+ ino := d.NextIno()
+ iops := &fsutil.SimpleFileInode{
+ InodeSimpleAttributes: fsutil.NewInodeSimpleAttributes(ctx, fs.FileOwnerFromContext(ctx), fs.FilePermissions{
+ User: fs.PermMask{Read: true, Write: true},
+ }, linux.SOCKFS_MAGIC),
+ }
+ inode := fs.NewInode(ctx, iops, fs.NewPseudoMountSource(ctx), fs.StableAttr{
+ Type: fs.Socket,
+ DeviceID: d.DeviceID(),
+ InodeID: ino,
+ BlockSize: usermem.PageSize,
+ })
+
+ // Dirent name matches net/socket.c:sockfs_dname.
+ return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino))
+}
+
+// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for
+// specific address families (e.g., AF_INET).
+type ProviderVFS2 interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error)
+}
+
+// familiesVFS2 holds a map of all known address families and their providers.
+var familiesVFS2 = make(map[int][]ProviderVFS2)
+
+// RegisterProviderVFS2 registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProviderVFS2(family int, provider ProviderVFS2) {
+ familiesVFS2[family] = append(familiesVFS2[family], provider)
+}
+
+// NewVFS2 creates a new socket with the given family, type and protocol.
+func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ for _, p := range familiesVFS2[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocketVFS2(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// PairVFS2 creates a new connected socket pair with the given family, type and
+// protocol.
+func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ providers, ok := familiesVFS2[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocketVFS2(s1)
+ k.RecordSocketVFS2(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
+// SendReceiveTimeout stores timeouts for send and receive calls.
+//
+// It is meant to be embedded into Socket implementations to help satisfy the
+// interface.
+//
+// Care must be taken when copying SendReceiveTimeout as it contains atomic
+// variables.
+//
+// +stateify savable
+type SendReceiveTimeout struct {
+ // send is length of the send timeout in nanoseconds.
+ //
+ // send must be accessed atomically.
+ send int64
+
+ // recv is length of the receive timeout in nanoseconds.
+ //
+ // recv must be accessed atomically.
+ recv int64
+}
+
+// SetRecvTimeout implements Socket.SetRecvTimeout.
+func (to *SendReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.recv, nanoseconds)
+}
+
+// RecvTimeout implements Socket.RecvTimeout.
+func (to *SendReceiveTimeout) RecvTimeout() int64 {
+ return atomic.LoadInt64(&to.recv)
+}
+
+// SetSendTimeout implements Socket.SetSendTimeout.
+func (to *SendReceiveTimeout) SetSendTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&to.send, nanoseconds)
+}
+
+// SendTimeout implements Socket.SendTimeout.
+func (to *SendReceiveTimeout) SendTimeout() int64 {
+ return atomic.LoadInt64(&to.send)
+}
+
+// GetSockOptEmitUnimplementedEvent emits unimplemented event if name is valid.
+// It contains names that are valid for GetSockOpt when level is SOL_SOCKET.
+func GetSockOptEmitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_ACCEPTCONN,
+ linux.SO_BPF_EXTENSIONS,
+ linux.SO_COOKIE,
+ linux.SO_DOMAIN,
+ linux.SO_ERROR,
+ linux.SO_GET_FILTER,
+ linux.SO_INCOMING_NAPI_ID,
+ linux.SO_MEMINFO,
+ linux.SO_PEERCRED,
+ linux.SO_PEERGROUPS,
+ linux.SO_PEERNAME,
+ linux.SO_PEERSEC,
+ linux.SO_PROTOCOL,
+ linux.SO_SNDLOWAT,
+ linux.SO_TYPE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEvent(t, name)
+ }
+}
+
+// SetSockOptEmitUnimplementedEvent emits unimplemented event if name is valid.
+// It contains names that are valid for SetSockOpt when level is SOL_SOCKET.
+func SetSockOptEmitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_ATTACH_BPF,
+ linux.SO_ATTACH_FILTER,
+ linux.SO_ATTACH_REUSEPORT_CBPF,
+ linux.SO_ATTACH_REUSEPORT_EBPF,
+ linux.SO_CNX_ADVICE,
+ linux.SO_DETACH_FILTER,
+ linux.SO_RCVBUFFORCE,
+ linux.SO_SNDBUFFORCE:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+
+ default:
+ emitUnimplementedEvent(t, name)
+ }
+}
+
+// emitUnimplementedEvent emits unimplemented event if name is valid. It
+// contains names that are common between Get and SetSocketOpt when level is
+// SOL_SOCKET.
+func emitUnimplementedEvent(t *kernel.Task, name int) {
+ switch name {
+ case linux.SO_BINDTODEVICE,
+ linux.SO_BROADCAST,
+ linux.SO_BSDCOMPAT,
+ linux.SO_BUSY_POLL,
+ linux.SO_DEBUG,
+ linux.SO_DONTROUTE,
+ linux.SO_INCOMING_CPU,
+ linux.SO_KEEPALIVE,
+ linux.SO_LINGER,
+ linux.SO_LOCK_FILTER,
+ linux.SO_MARK,
+ linux.SO_MAX_PACING_RATE,
+ linux.SO_NOFCS,
+ linux.SO_OOBINLINE,
+ linux.SO_PASSCRED,
+ linux.SO_PASSSEC,
+ linux.SO_PEEK_OFF,
+ linux.SO_PRIORITY,
+ linux.SO_RCVBUF,
+ linux.SO_RCVLOWAT,
+ linux.SO_RCVTIMEO,
+ linux.SO_REUSEADDR,
+ linux.SO_REUSEPORT,
+ linux.SO_RXQ_OVFL,
+ linux.SO_SELECT_ERR_QUEUE,
+ linux.SO_SNDBUF,
+ linux.SO_SNDTIMEO,
+ linux.SO_TIMESTAMP,
+ linux.SO_TIMESTAMPING,
+ linux.SO_TIMESTAMPNS,
+ linux.SO_TXTIME,
+ linux.SO_WIFI_STATUS,
+ linux.SO_ZEROCOPY:
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ }
+}
+
+// UnmarshalSockAddr unmarshals memory representing a struct sockaddr to one of
+// the ABI socket address types.
+//
+// Precondition: data must be long enough to represent a socket address of the
+// given family.
+func UnmarshalSockAddr(family int, data []byte) linux.SockAddr {
+ switch family {
+ case syscall.AF_INET:
+ var addr linux.SockAddrInet
+ binary.Unmarshal(data[:syscall.SizeofSockaddrInet4], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_INET6:
+ var addr linux.SockAddrInet6
+ binary.Unmarshal(data[:syscall.SizeofSockaddrInet6], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_UNIX:
+ var addr linux.SockAddrUnix
+ binary.Unmarshal(data[:syscall.SizeofSockaddrUnix], usermem.ByteOrder, &addr)
+ return &addr
+ case syscall.AF_NETLINK:
+ var addr linux.SockAddrNetlink
+ binary.Unmarshal(data[:syscall.SizeofSockaddrNetlink], usermem.ByteOrder, &addr)
+ return &addr
+ default:
+ panic(fmt.Sprintf("Unsupported socket family %v", family))
+ }
+}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
new file mode 100644
index 000000000..cca5e70f1
--- /dev/null
+++ b/pkg/sentry/socket/unix/BUILD
@@ -0,0 +1,39 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "unix",
+ srcs = [
+ "device.go",
+ "io.go",
+ "unix.go",
+ "unix_vfs2.go",
+ ],
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/fspath",
+ "//pkg/refs",
+ "//pkg/safemem",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fs/lock",
+ "//pkg/sentry/fsimpl/sockfs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/netstack",
+ "//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/usermem",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/socket/unix/device.go b/pkg/sentry/socket/unix/device.go
new file mode 100644
index 000000000..db01ac4c9
--- /dev/null
+++ b/pkg/sentry/socket/unix/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import "gvisor.dev/gvisor/pkg/sentry/device"
+
+// unixSocketDevice is the unix socket virtual device.
+var unixSocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
new file mode 100644
index 000000000..129949990
--- /dev/null
+++ b/pkg/sentry/socket/unix/io.go
@@ -0,0 +1,111 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/safemem"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/tcpip"
+)
+
+// EndpointWriter implements safemem.Writer that writes to a transport.Endpoint.
+//
+// EndpointWriter is not thread-safe.
+type EndpointWriter struct {
+ Ctx context.Context
+
+ // Endpoint is the transport.Endpoint to write to.
+ Endpoint transport.Endpoint
+
+ // Control is the control messages to send.
+ Control transport.ControlMessages
+
+ // To is the endpoint to send to. May be nil.
+ To transport.BoundEndpoint
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecWriterFunc{func(bufs [][]byte) (int64, error) {
+ n, err := w.Endpoint.SendMsg(w.Ctx, bufs, w.Control, w.To)
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.WriteFromBlocks(srcs)
+}
+
+// EndpointReader implements safemem.Reader that reads from a
+// transport.Endpoint.
+//
+// EndpointReader is not thread-safe.
+type EndpointReader struct {
+ Ctx context.Context
+
+ // Endpoint is the transport.Endpoint to read from.
+ Endpoint transport.Endpoint
+
+ // Creds indicates if credential control messages are requested.
+ Creds bool
+
+ // NumRights is the number of SCM_RIGHTS FDs requested.
+ NumRights int
+
+ // Peek indicates that the data should not be consumed from the
+ // endpoint.
+ Peek bool
+
+ // MsgSize is the size of the message that was read from. For stream
+ // sockets, it is the amount read.
+ MsgSize int64
+
+ // From, if not nil, will be set with the address read from.
+ From *tcpip.FullAddress
+
+ // Control contains the received control messages.
+ Control transport.ControlMessages
+
+ // ControlTrunc indicates that SCM_RIGHTS FDs were discarded based on
+ // the value of NumRights.
+ ControlTrunc bool
+}
+
+// Truncate calls RecvMsg on the endpoint without writing to a destination.
+func (r *EndpointReader) Truncate() error {
+ // Ignore bytes read since it will always be zero.
+ _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return err.ToError()
+ }
+ return nil
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+ return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) {
+ n, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, bufs, r.Creds, r.NumRights, r.Peek, r.From)
+ r.Control = c
+ r.ControlTrunc = ct
+ r.MsgSize = ms
+ if err != nil {
+ return int64(n), err.ToError()
+ }
+ return int64(n), nil
+ }}.ReadToBlocks(dsts)
+}
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
new file mode 100644
index 000000000..c708b6030
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -0,0 +1,41 @@
+load("//tools:defs.bzl", "go_library")
+load("//tools/go_generics:defs.bzl", "go_template_instance")
+
+package(licenses = ["notice"])
+
+go_template_instance(
+ name = "transport_message_list",
+ out = "transport_message_list.go",
+ package = "transport",
+ prefix = "message",
+ template = "//pkg/ilist:generic_list",
+ types = {
+ "Element": "*message",
+ "Linker": "*message",
+ },
+)
+
+go_library(
+ name = "transport",
+ srcs = [
+ "connectioned.go",
+ "connectioned_state.go",
+ "connectionless.go",
+ "queue.go",
+ "transport_message_list.go",
+ "unix.go",
+ ],
+ visibility = ["//:sandbox"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/context",
+ "//pkg/ilist",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sync",
+ "//pkg/syserr",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
new file mode 100644
index 000000000..a1e49cc57
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -0,0 +1,486 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// UniqueIDProvider generates a sequence of unique identifiers useful for,
+// among other things, lock ordering.
+type UniqueIDProvider interface {
+ // UniqueID returns a new unique identifier.
+ UniqueID() uint64
+}
+
+// A ConnectingEndpoint is a connectioned unix endpoint that is attempting to
+// establish a bidirectional connection with a BoundEndpoint.
+type ConnectingEndpoint interface {
+ // ID returns the endpoint's globally unique identifier. This identifier
+ // must be used to determine locking order if more than one endpoint is
+ // to be locked in the same codepath. The endpoint with the smaller
+ // identifier must be locked before endpoints with larger identifiers.
+ ID() uint64
+
+ // Passcred implements socket.Credentialer.Passcred.
+ Passcred() bool
+
+ // Type returns the socket type, typically either SockStream or
+ // SockSeqpacket. The connection attempt must be aborted if this
+ // value doesn't match the ConnectableEndpoint's type.
+ Type() linux.SockType
+
+ // GetLocalAddress returns the bound path.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Locker protects the following methods. While locked, only the holder of
+ // the lock can change the return value of the protected methods.
+ sync.Locker
+
+ // Connected returns true iff the ConnectingEndpoint is in the connected
+ // state. ConnectingEndpoints can only be connected to a single endpoint,
+ // so the connection attempt must be aborted if this returns true.
+ Connected() bool
+
+ // Listening returns true iff the ConnectingEndpoint is in the listening
+ // state. ConnectingEndpoints cannot make connections while listening, so
+ // the connection attempt must be aborted if this returns true.
+ Listening() bool
+
+ // WaiterQueue returns a pointer to the endpoint's waiter queue.
+ WaiterQueue() *waiter.Queue
+}
+
+// connectionedEndpoint is a Unix-domain connected or connectable endpoint and implements
+// ConnectingEndpoint, ConnectableEndpoint and tcpip.Endpoint.
+//
+// connectionedEndpoints must be in connected state in order to transfer data.
+//
+// This implementation includes STREAM and SEQPACKET Unix sockets created with
+// socket(2), accept(2) or socketpair(2) and dgram unix sockets created with
+// socketpair(2). See unix_connectionless.go for the implementation of DGRAM
+// Unix sockets created with socket(2).
+//
+// The state is much simpler than a TCP endpoint, so it is not encoded
+// explicitly. Instead we enforce the following invariants:
+//
+// receiver != nil, connected != nil => connected.
+// path != "" && acceptedChan == nil => bound, not listening.
+// path != "" && acceptedChan != nil => bound and listening.
+//
+// Only one of these will be true at any moment.
+//
+// +stateify savable
+type connectionedEndpoint struct {
+ baseEndpoint
+
+ // id is the unique endpoint identifier. This is used exclusively for
+ // lock ordering within connect.
+ id uint64
+
+ // idGenerator is used to generate new unique endpoint identifiers.
+ idGenerator UniqueIDProvider
+
+ // stype is used by connecting sockets to ensure that they are the
+ // same type. The value is typically either tcpip.SockSeqpacket or
+ // tcpip.SockStream.
+ stype linux.SockType
+
+ // acceptedChan is per the TCP endpoint implementation. Note that the
+ // sockets in this channel are _already in the connected state_, and
+ // have another associated connectionedEndpoint.
+ //
+ // If nil, then no listen call has been made.
+ acceptedChan chan *connectionedEndpoint `state:".([]*connectionedEndpoint)"`
+}
+
+var (
+ _ = BoundEndpoint((*connectionedEndpoint)(nil))
+ _ = Endpoint((*connectionedEndpoint)(nil))
+)
+
+// NewConnectioned creates a new unbound connectionedEndpoint.
+func NewConnectioned(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// NewPair allocates a new pair of connected unix-domain connectionedEndpoints.
+func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (Endpoint, Endpoint) {
+ a := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+ b := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: &waiter.Queue{}},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+
+ q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
+ q1.EnableLeakCheck("transport.queue")
+ q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
+ q2.EnableLeakCheck("transport.queue")
+
+ if stype == linux.SOCK_STREAM {
+ a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
+ b.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q2}}
+ } else {
+ a.receiver = &queueReceiver{q1}
+ b.receiver = &queueReceiver{q2}
+ }
+
+ q2.IncRef()
+ a.connected = &connectedEndpoint{
+ endpoint: b,
+ writeQueue: q2,
+ }
+ q1.IncRef()
+ b.connected = &connectedEndpoint{
+ endpoint: a,
+ writeQueue: q1,
+ }
+
+ return a, b
+}
+
+// NewExternal creates a new externally backed Endpoint. It behaves like a
+// socketpair.
+func NewExternal(ctx context.Context, stype linux.SockType, uid UniqueIDProvider, queue *waiter.Queue, receiver Receiver, connected ConnectedEndpoint) Endpoint {
+ return &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{Queue: queue, receiver: receiver, connected: connected},
+ id: uid.UniqueID(),
+ idGenerator: uid,
+ stype: stype,
+ }
+}
+
+// ID implements ConnectingEndpoint.ID.
+func (e *connectionedEndpoint) ID() uint64 {
+ return e.id
+}
+
+// Type implements ConnectingEndpoint.Type and Endpoint.Type.
+func (e *connectionedEndpoint) Type() linux.SockType {
+ return e.stype
+}
+
+// WaiterQueue implements ConnectingEndpoint.WaiterQueue.
+func (e *connectionedEndpoint) WaiterQueue() *waiter.Queue {
+ return e.Queue
+}
+
+// isBound returns true iff the connectionedEndpoint is bound (but not
+// listening).
+func (e *connectionedEndpoint) isBound() bool {
+ return e.path != "" && e.acceptedChan == nil
+}
+
+// Listening implements ConnectingEndpoint.Listening.
+func (e *connectionedEndpoint) Listening() bool {
+ return e.acceptedChan != nil
+}
+
+// Close puts the connectionedEndpoint in a closed state and frees all
+// resources associated with it.
+//
+// The socket will be a fresh state after a call to close and may be reused.
+// That is, close may be used to "unbind" or "disconnect" the socket in error
+// paths.
+func (e *connectionedEndpoint) Close() {
+ e.Lock()
+ var c ConnectedEndpoint
+ var r Receiver
+ switch {
+ case e.Connected():
+ e.connected.CloseSend()
+ e.receiver.CloseRecv()
+ // Still have unread data? If yes, we set this into the write
+ // end so that the peer can get ECONNRESET) when it does read.
+ if e.receiver.RecvQueuedSize() > 0 {
+ e.connected.CloseUnread()
+ }
+ c = e.connected
+ r = e.receiver
+ e.connected = nil
+ e.receiver = nil
+ case e.isBound():
+ e.path = ""
+ case e.Listening():
+ close(e.acceptedChan)
+ for n := range e.acceptedChan {
+ n.Close()
+ }
+ e.acceptedChan = nil
+ e.path = ""
+ }
+ e.Unlock()
+ if c != nil {
+ c.CloseNotify()
+ c.Release()
+ }
+ if r != nil {
+ r.CloseNotify()
+ r.Release()
+ }
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ if ce.Type() != e.stype {
+ return syserr.ErrWrongProtocolForSocket
+ }
+
+ // Check if ce is e to avoid a deadlock.
+ if ce, ok := ce.(*connectionedEndpoint); ok && ce == e {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Do a dance to safely acquire locks on both endpoints.
+ if e.id < ce.ID() {
+ e.Lock()
+ ce.Lock()
+ } else {
+ ce.Lock()
+ e.Lock()
+ }
+
+ // Check connecting state.
+ if ce.Connected() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrAlreadyConnected
+ }
+ if ce.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Check bound state.
+ if !e.Listening() {
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+
+ // Create a newly bound connectionedEndpoint.
+ ne := &connectionedEndpoint{
+ baseEndpoint: baseEndpoint{
+ path: e.path,
+ Queue: &waiter.Queue{},
+ },
+ id: e.idGenerator.UniqueID(),
+ idGenerator: e.idGenerator,
+ stype: e.stype,
+ }
+
+ readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
+ readQueue.EnableLeakCheck("transport.queue")
+ ne.connected = &connectedEndpoint{
+ endpoint: ce,
+ writeQueue: readQueue,
+ }
+
+ writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
+ writeQueue.EnableLeakCheck("transport.queue")
+ if e.stype == linux.SOCK_STREAM {
+ ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
+ } else {
+ ne.receiver = &queueReceiver{readQueue: writeQueue}
+ }
+
+ select {
+ case e.acceptedChan <- ne:
+ // Commit state.
+ writeQueue.IncRef()
+ connected := &connectedEndpoint{
+ endpoint: ne,
+ writeQueue: writeQueue,
+ }
+ readQueue.IncRef()
+ if e.stype == linux.SOCK_STREAM {
+ returnConnect(&streamQueueReceiver{queueReceiver: queueReceiver{readQueue: readQueue}}, connected)
+ } else {
+ returnConnect(&queueReceiver{readQueue: readQueue}, connected)
+ }
+
+ // Notify can deadlock if we are holding these locks.
+ e.Unlock()
+ ce.Unlock()
+
+ // Notify on both ends.
+ e.Notify(waiter.EventIn)
+ ce.WaiterQueue().Notify(waiter.EventOut)
+
+ return nil
+ default:
+ // Busy; return ECONNREFUSED per spec.
+ ne.Close()
+ e.Unlock()
+ ce.Unlock()
+ return syserr.ErrConnectionRefused
+ }
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionedEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
+ return nil, syserr.ErrConnectionRefused
+}
+
+// Connect attempts to directly connect to another Endpoint.
+// Implements Endpoint.Connect.
+func (e *connectionedEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
+ returnConnect := func(r Receiver, ce ConnectedEndpoint) {
+ e.receiver = r
+ e.connected = ce
+ }
+
+ return server.BidirectionalConnect(ctx, e, returnConnect)
+}
+
+// Listen starts listening on the connection.
+func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.Listening() {
+ // Adjust the size of the channel iff we can fix existing
+ // pending connections into the new one.
+ if len(e.acceptedChan) > backlog {
+ return syserr.ErrInvalidEndpointState
+ }
+ origChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ close(origChan)
+ for ep := range origChan {
+ e.acceptedChan <- ep
+ }
+ return nil
+ }
+ if !e.isBound() {
+ return syserr.ErrInvalidEndpointState
+ }
+
+ // Normal case.
+ e.acceptedChan = make(chan *connectionedEndpoint, backlog)
+ return nil
+}
+
+// Accept accepts a new connection.
+func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
+ e.Lock()
+ defer e.Unlock()
+
+ if !e.Listening() {
+ return nil, syserr.ErrInvalidEndpointState
+ }
+
+ select {
+ case ne := <-e.acceptedChan:
+ return ne, nil
+
+ default:
+ // Nothing left.
+ return nil, syserr.ErrWouldBlock
+ }
+}
+
+// Bind binds the connection.
+//
+// For Unix connectionedEndpoints, this _only sets the address associated with
+// the socket_. Work associated with sockets in the filesystem or finding those
+// sockets must be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionedEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() || e.Listening() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *connectionedEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
+ // Stream sockets do not support specifying the endpoint. Seqpacket
+ // sockets ignore the passed endpoint.
+ if e.stype == linux.SOCK_STREAM && to != nil {
+ return 0, syserr.ErrNotSupported
+ }
+ return e.baseEndpoint.SendMsg(ctx, data, c, to)
+}
+
+// Readiness returns the current readiness of the connectionedEndpoint. For
+// example, if waiter.EventIn is set, the connectionedEndpoint is immediately
+// readable.
+func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ switch {
+ case e.Connected():
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ case e.Listening():
+ if mask&waiter.EventIn != 0 && len(e.acceptedChan) > 0 {
+ ready |= waiter.EventIn
+ }
+ }
+
+ return ready
+}
+
+// State implements socket.Socket.State.
+func (e *connectionedEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
+ if e.Connected() {
+ return linux.SS_CONNECTED
+ }
+ return linux.SS_UNCONNECTED
+}
diff --git a/pkg/sentry/socket/unix/transport/connectioned_state.go b/pkg/sentry/socket/unix/transport/connectioned_state.go
new file mode 100644
index 000000000..7e02a5db8
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectioned_state.go
@@ -0,0 +1,53 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+// saveAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) saveAcceptedChan() []*connectionedEndpoint {
+ // If acceptedChan is nil (i.e. we are not listening) then we will save nil.
+ // Otherwise we create a (possibly empty) slice of the values in acceptedChan and
+ // save that.
+ var acceptedSlice []*connectionedEndpoint
+ if e.acceptedChan != nil {
+ // Swap out acceptedChan with a new empty channel of the same capacity.
+ saveChan := e.acceptedChan
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(saveChan))
+
+ // Create a new slice with the same len and capacity as the channel.
+ acceptedSlice = make([]*connectionedEndpoint, len(saveChan), cap(saveChan))
+ // Drain acceptedChan into saveSlice, and fill up the new acceptChan at the
+ // same time.
+ for i := range acceptedSlice {
+ ep := <-saveChan
+ acceptedSlice[i] = ep
+ e.acceptedChan <- ep
+ }
+ close(saveChan)
+ }
+ return acceptedSlice
+}
+
+// loadAcceptedChan is invoked by stateify.
+func (e *connectionedEndpoint) loadAcceptedChan(acceptedSlice []*connectionedEndpoint) {
+ // If acceptedSlice is nil, then acceptedChan should also be nil.
+ if acceptedSlice != nil {
+ // Otherwise, create a new channel with the same capacity as acceptedSlice.
+ e.acceptedChan = make(chan *connectionedEndpoint, cap(acceptedSlice))
+ // Seed the channel with values from acceptedSlice.
+ for _, ep := range acceptedSlice {
+ e.acceptedChan <- ep
+ }
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
new file mode 100644
index 000000000..4b06d63ac
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -0,0 +1,218 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// connectionlessEndpoint is a unix endpoint for unix sockets that support operating in
+// a connectionless fashon.
+//
+// Specifically, this means datagram unix sockets not created with
+// socketpair(2).
+//
+// +stateify savable
+type connectionlessEndpoint struct {
+ baseEndpoint
+}
+
+var (
+ _ = BoundEndpoint((*connectionlessEndpoint)(nil))
+ _ = Endpoint((*connectionlessEndpoint)(nil))
+)
+
+// NewConnectionless creates a new unbound dgram endpoint.
+func NewConnectionless(ctx context.Context) Endpoint {
+ ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
+ q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
+ q.EnableLeakCheck("transport.queue")
+ ep.receiver = &queueReceiver{readQueue: &q}
+ return ep
+}
+
+// isBound returns true iff the endpoint is bound.
+func (e *connectionlessEndpoint) isBound() bool {
+ return e.path != ""
+}
+
+// Close puts the endpoint in a closed state and frees all resources associated
+// with it.
+func (e *connectionlessEndpoint) Close() {
+ e.Lock()
+ if e.connected != nil {
+ e.connected.Release()
+ e.connected = nil
+ }
+
+ if e.isBound() {
+ e.path = ""
+ }
+
+ e.receiver.CloseRecv()
+ r := e.receiver
+ e.receiver = nil
+ e.Unlock()
+
+ r.CloseNotify()
+ r.Release()
+}
+
+// BidirectionalConnect implements BoundEndpoint.BidirectionalConnect.
+func (e *connectionlessEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error {
+ return syserr.ErrConnectionRefused
+}
+
+// UnidirectionalConnect implements BoundEndpoint.UnidirectionalConnect.
+func (e *connectionlessEndpoint) UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error) {
+ e.Lock()
+ r := e.receiver
+ e.Unlock()
+ if r == nil {
+ return nil, syserr.ErrConnectionRefused
+ }
+ q := r.(*queueReceiver).readQueue
+ if !q.TryIncRef() {
+ return nil, syserr.ErrConnectionRefused
+ }
+ return &connectedEndpoint{
+ endpoint: e,
+ writeQueue: q,
+ }, nil
+}
+
+// SendMsg writes data and a control message to the specified endpoint.
+// This method does not block if the data cannot be written.
+func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
+ if to == nil {
+ return e.baseEndpoint.SendMsg(ctx, data, c, nil)
+ }
+
+ connected, err := to.UnidirectionalConnect(ctx)
+ if err != nil {
+ return 0, syserr.ErrInvalidEndpointState
+ }
+ defer connected.Release()
+
+ e.Lock()
+ n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ connected.SendNotify()
+ }
+
+ return n, err
+}
+
+// Type implements Endpoint.Type.
+func (e *connectionlessEndpoint) Type() linux.SockType {
+ return linux.SOCK_DGRAM
+}
+
+// Connect attempts to connect directly to server.
+func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoint) *syserr.Error {
+ connected, err := server.UnidirectionalConnect(ctx)
+ if err != nil {
+ return err
+ }
+
+ e.Lock()
+ if e.connected != nil {
+ e.connected.Release()
+ }
+ e.connected = connected
+ e.Unlock()
+
+ return nil
+}
+
+// Listen starts listening on the connection.
+func (e *connectionlessEndpoint) Listen(int) *syserr.Error {
+ return syserr.ErrNotSupported
+}
+
+// Accept accepts a new connection.
+func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) {
+ return nil, syserr.ErrNotSupported
+}
+
+// Bind binds the connection.
+//
+// For Unix endpoints, this _only sets the address associated with the socket_.
+// Work associated with sockets in the filesystem or finding those sockets must
+// be done by a higher level.
+//
+// Bind will fail only if the socket is connected, bound or the passed address
+// is invalid (the empty string).
+func (e *connectionlessEndpoint) Bind(addr tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error {
+ e.Lock()
+ defer e.Unlock()
+ if e.isBound() {
+ return syserr.ErrAlreadyBound
+ }
+ if addr.Addr == "" {
+ // The empty string is not permitted.
+ return syserr.ErrBadLocalAddress
+ }
+ if commit != nil {
+ if err := commit(); err != nil {
+ return err
+ }
+ }
+
+ // Save the bound address.
+ e.path = string(addr.Addr)
+ return nil
+}
+
+// Readiness returns the current readiness of the endpoint. For example, if
+// waiter.EventIn is set, the endpoint is immediately readable.
+func (e *connectionlessEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask {
+ e.Lock()
+ defer e.Unlock()
+
+ ready := waiter.EventMask(0)
+ if mask&waiter.EventIn != 0 && e.receiver.Readable() {
+ ready |= waiter.EventIn
+ }
+
+ if e.Connected() {
+ if mask&waiter.EventOut != 0 && e.connected.Writable() {
+ ready |= waiter.EventOut
+ }
+ }
+
+ return ready
+}
+
+// State implements socket.Socket.State.
+func (e *connectionlessEndpoint) State() uint32 {
+ e.Lock()
+ defer e.Unlock()
+
+ switch {
+ case e.isBound():
+ return linux.SS_UNCONNECTED
+ case e.Connected():
+ return linux.SS_CONNECTING
+ default:
+ return linux.SS_DISCONNECTING
+ }
+}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
new file mode 100644
index 000000000..d8f3ad63d
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -0,0 +1,247 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package transport
+
+import (
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// queue is a buffer queue.
+//
+// +stateify savable
+type queue struct {
+ refs.AtomicRefCount
+
+ ReaderQueue *waiter.Queue
+ WriterQueue *waiter.Queue
+
+ mu sync.Mutex `state:"nosave"`
+ closed bool
+ unread bool
+ used int64
+ limit int64
+ dataList messageList
+}
+
+// Close closes q for reading and writing. It is immediately not writable and
+// will become unreadable when no more data is pending.
+//
+// Both the read and write queues must be notified after closing:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Close() {
+ q.mu.Lock()
+ q.closed = true
+ q.mu.Unlock()
+}
+
+// Reset empties the queue and Releases all of the Entries.
+//
+// Both the read and write queues must be notified after resetting:
+// q.ReaderQueue.Notify(waiter.EventIn)
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Reset() {
+ q.mu.Lock()
+ for cur := q.dataList.Front(); cur != nil; cur = cur.Next() {
+ cur.Release()
+ }
+ q.dataList.Reset()
+ q.used = 0
+ q.mu.Unlock()
+}
+
+// DecRef implements RefCounter.DecRef with destructor q.Reset.
+func (q *queue) DecRef() {
+ q.DecRefWithDestructor(q.Reset)
+ // We don't need to notify after resetting because no one cares about
+ // this queue after all references have been dropped.
+}
+
+// IsReadable determines if q is currently readable.
+func (q *queue) IsReadable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.dataList.Front() != nil
+}
+
+// bufWritable returns true if there is space for writing.
+//
+// N.B. Linux only considers a unix socket "writable" if >75% of the buffer is
+// free.
+//
+// See net/unix/af_unix.c:unix_writeable.
+func (q *queue) bufWritable() bool {
+ return 4*q.used < q.limit
+}
+
+// IsWritable determines if q is currently writable.
+func (q *queue) IsWritable() bool {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ return q.closed || q.bufWritable()
+}
+
+// Enqueue adds an entry to the data queue if room is available.
+//
+// If discardEmpty is true and there are zero bytes of data, the packet is
+// dropped.
+//
+// If truncate is true, Enqueue may truncate the message before enqueuing it.
+// Otherwise, the entire message must fit. If l is less than the size of data,
+// err indicates why.
+//
+// If notify is true, ReaderQueue.Notify must be called:
+// q.ReaderQueue.Notify(waiter.EventIn)
+func (q *queue) Enqueue(data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.closed {
+ q.mu.Unlock()
+ return 0, false, syserr.ErrClosedForSend
+ }
+
+ for _, d := range data {
+ l += int64(len(d))
+ }
+ if discardEmpty && l == 0 {
+ q.mu.Unlock()
+ c.Release()
+ return 0, false, nil
+ }
+
+ free := q.limit - q.used
+
+ if l > free && truncate {
+ if free == 0 {
+ // Message can't fit right now.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ l = free
+ err = syserr.ErrWouldBlock
+ }
+
+ if l > q.limit {
+ // Message is too big to ever fit.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrMessageTooLong
+ }
+
+ if l > free {
+ // Message can't fit right now, and could not be truncated.
+ q.mu.Unlock()
+ return 0, false, syserr.ErrWouldBlock
+ }
+
+ // Aggregate l bytes of data. This will truncate the data if l is less than
+ // the total bytes held in data.
+ v := make([]byte, l)
+ for i, b := 0, v; i < len(data) && len(b) > 0; i++ {
+ n := copy(b, data[i])
+ b = b[n:]
+ }
+
+ notify = q.dataList.Front() == nil
+ q.used += l
+ q.dataList.PushBack(&message{
+ Data: buffer.View(v),
+ Control: c,
+ Address: from,
+ })
+
+ q.mu.Unlock()
+
+ return l, notify, err
+}
+
+// Dequeue removes the first entry in the data queue, if one exists.
+//
+// If notify is true, WriterQueue.Notify must be called:
+// q.WriterQueue.Notify(waiter.EventOut)
+func (q *queue) Dequeue() (e *message, notify bool, err *syserr.Error) {
+ q.mu.Lock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ err = syserr.ErrClosedForReceive
+ if q.unread {
+ err = syserr.ErrConnectionReset
+ }
+ }
+ q.mu.Unlock()
+
+ return nil, false, err
+ }
+
+ notify = !q.bufWritable()
+
+ e = q.dataList.Front()
+ q.dataList.Remove(e)
+ q.used -= e.Length()
+
+ notify = notify && q.bufWritable()
+
+ q.mu.Unlock()
+
+ return e, notify, nil
+}
+
+// Peek returns the first entry in the data queue, if one exists.
+func (q *queue) Peek() (*message, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if q.dataList.Front() == nil {
+ err := syserr.ErrWouldBlock
+ if q.closed {
+ if err = syserr.ErrClosedForReceive; q.unread {
+ err = syserr.ErrConnectionReset
+ }
+ }
+ return nil, err
+ }
+
+ return q.dataList.Front().Peek(), nil
+}
+
+// QueuedSize returns the number of bytes currently in the queue, that is, the
+// number of readable bytes.
+func (q *queue) QueuedSize() int64 {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ return q.used
+}
+
+// MaxQueueSize returns the maximum number of bytes storable in the queue.
+func (q *queue) MaxQueueSize() int64 {
+ return q.limit
+}
+
+// CloseUnread sets flag to indicate that the peer is closed (not shutdown)
+// with unread data. So if read on this queue shall return ECONNRESET error.
+func (q *queue) CloseUnread() {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+ q.unread = true
+}
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
new file mode 100644
index 000000000..2f1b127df
--- /dev/null
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -0,0 +1,1006 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package transport contains the implementation of Unix endpoints.
+package transport
+
+import (
+ "sync/atomic"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/sync"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/buffer"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// initialLimit is the starting limit for the socket buffers.
+const initialLimit = 16 * 1024
+
+// A RightsControlMessage is a control message containing FDs.
+type RightsControlMessage interface {
+ // Clone returns a copy of the RightsControlMessage.
+ Clone() RightsControlMessage
+
+ // Release releases any resources owned by the RightsControlMessage.
+ Release()
+}
+
+// A CredentialsControlMessage is a control message containing Unix credentials.
+type CredentialsControlMessage interface {
+ // Equals returns true iff the two messages are equal.
+ Equals(CredentialsControlMessage) bool
+}
+
+// A ControlMessages represents a collection of socket control messages.
+//
+// +stateify savable
+type ControlMessages struct {
+ // Rights is a control message containing FDs.
+ Rights RightsControlMessage
+
+ // Credentials is a control message containing Unix credentials.
+ Credentials CredentialsControlMessage
+}
+
+// Empty returns true iff the ControlMessages does not contain either
+// credentials or rights.
+func (c *ControlMessages) Empty() bool {
+ return c.Rights == nil && c.Credentials == nil
+}
+
+// Clone clones both the credentials and the rights.
+func (c *ControlMessages) Clone() ControlMessages {
+ cm := ControlMessages{}
+ if c.Rights != nil {
+ cm.Rights = c.Rights.Clone()
+ }
+ cm.Credentials = c.Credentials
+ return cm
+}
+
+// Release releases both the credentials and the rights.
+func (c *ControlMessages) Release() {
+ if c.Rights != nil {
+ c.Rights.Release()
+ }
+ *c = ControlMessages{}
+}
+
+// Endpoint is the interface implemented by Unix transport protocol
+// implementations that expose functionality like sendmsg, recvmsg, connect,
+// etc. to Unix socket implementations.
+type Endpoint interface {
+ Credentialer
+ waiter.Waitable
+
+ // Close puts the endpoint in a closed state and frees all resources
+ // associated with it.
+ Close()
+
+ // RecvMsg reads data and a control message from the endpoint. This method
+ // does not block if there is no data pending.
+ //
+ // creds indicates if credential control messages are requested by the
+ // caller. This is useful for determining if control messages can be
+ // coalesced. creds is a hint and can be safely ignored by the
+ // implementation if no coalescing is possible. It is fine to return
+ // credential control messages when none were requested or to not return
+ // credential control messages when they were requested.
+ //
+ // numRights is the number of SCM_RIGHTS FDs requested by the caller. This
+ // is useful if one must allocate a buffer to receive a SCM_RIGHTS message
+ // or determine if control messages can be coalesced. numRights is a hint
+ // and can be safely ignored by the implementation if the number of
+ // available SCM_RIGHTS FDs is known and no coalescing is possible. It is
+ // fine for the returned number of SCM_RIGHTS FDs to be either higher or
+ // lower than the requested number.
+ //
+ // If peek is true, no data should be consumed from the Endpoint. Any and
+ // all data returned from a peek should be available in the next call to
+ // RecvMsg.
+ //
+ // recvLen is the number of bytes copied into data.
+ //
+ // msgLen is the length of the read message consumed for datagram Endpoints.
+ // msgLen is always the same as recvLen for stream Endpoints.
+ //
+ // CMTruncated indicates that the numRights hint was used to receive fewer
+ // than the total available SCM_RIGHTS FDs. Additional truncation may be
+ // required by the caller.
+ RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, err *syserr.Error)
+
+ // SendMsg writes data and a control message to the endpoint's peer.
+ // This method does not block if the data cannot be written.
+ //
+ // SendMsg does not take ownership of any of its arguments on error.
+ SendMsg(context.Context, [][]byte, ControlMessages, BoundEndpoint) (int64, *syserr.Error)
+
+ // Connect connects this endpoint directly to another.
+ //
+ // This should be called on the client endpoint, and the (bound)
+ // endpoint passed in as a parameter.
+ //
+ // The error codes are the same as Connect.
+ Connect(ctx context.Context, server BoundEndpoint) *syserr.Error
+
+ // Shutdown closes the read and/or write end of the endpoint connection
+ // to its peer.
+ Shutdown(flags tcpip.ShutdownFlags) *syserr.Error
+
+ // Listen puts the endpoint in "listen" mode, which allows it to accept
+ // new connections.
+ Listen(backlog int) *syserr.Error
+
+ // Accept returns a new endpoint if a peer has established a connection
+ // to an endpoint previously set to listen mode. This method does not
+ // block if no new connections are available.
+ //
+ // The returned Queue is the wait queue for the newly created endpoint.
+ Accept() (Endpoint, *syserr.Error)
+
+ // Bind binds the endpoint to a specific local address and port.
+ // Specifying a NIC is optional.
+ //
+ // An optional commit function will be executed atomically with respect
+ // to binding the endpoint. If this returns an error, the bind will not
+ // occur and the error will be propagated back to the caller.
+ Bind(address tcpip.FullAddress, commit func() *syserr.Error) *syserr.Error
+
+ // Type return the socket type, typically either SockStream, SockDgram
+ // or SockSeqpacket.
+ Type() linux.SockType
+
+ // GetLocalAddress returns the address to which the endpoint is bound.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress returns the address to which the endpoint is
+ // connected.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option
+ // types.
+ SetSockOpt(opt interface{}) *tcpip.Error
+
+ // SetSockOptBool sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error
+
+ // SetSockOptInt sets a socket option for simple cases when a value has
+ // the int type.
+ SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
+
+ // GetSockOpt gets a socket option. opt should be a pointer to one of the
+ // tcpip.*Option types.
+ GetSockOpt(opt interface{}) *tcpip.Error
+
+ // GetSockOptBool gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error)
+
+ // GetSockOptInt gets a socket option for simple cases when a return
+ // value has the int type.
+ GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+
+ // State returns the current state of the socket, as represented by Linux in
+ // procfs.
+ State() uint32
+}
+
+// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
+// option.
+type Credentialer interface {
+ // Passcred returns whether or not the SO_PASSCRED socket option is
+ // enabled on this end.
+ Passcred() bool
+
+ // ConnectedPasscred returns whether or not the SO_PASSCRED socket option
+ // is enabled on the connected end.
+ ConnectedPasscred() bool
+}
+
+// A BoundEndpoint is a unix endpoint that can be connected to.
+type BoundEndpoint interface {
+ // BidirectionalConnect establishes a bi-directional connection between two
+ // unix endpoints in an all-or-nothing manner. If an error occurs during
+ // connecting, the state of neither endpoint should be modified.
+ //
+ // In order for an endpoint to establish such a bidirectional connection
+ // with a BoundEndpoint, the endpoint calls the BidirectionalConnect method
+ // on the BoundEndpoint and sends a representation of itself (the
+ // ConnectingEndpoint) and a callback (returnConnect) to receive the
+ // connection information (Receiver and ConnectedEndpoint) upon a
+ // successful connect. The callback should only be called on a successful
+ // connect.
+ //
+ // For a connection attempt to be successful, the ConnectingEndpoint must
+ // be unconnected and not listening and the BoundEndpoint whose
+ // BidirectionalConnect method is being called must be listening.
+ //
+ // This method will return syserr.ErrConnectionRefused on endpoints with a
+ // type that isn't SockStream or SockSeqpacket.
+ BidirectionalConnect(ctx context.Context, ep ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error
+
+ // UnidirectionalConnect establishes a write-only connection to a unix
+ // endpoint.
+ //
+ // An endpoint which calls UnidirectionalConnect and supports it itself must
+ // not hold its own lock when calling UnidirectionalConnect.
+ //
+ // This method will return syserr.ErrConnectionRefused on a non-SockDgram
+ // endpoint.
+ UnidirectionalConnect(ctx context.Context) (ConnectedEndpoint, *syserr.Error)
+
+ // Passcred returns whether or not the SO_PASSCRED socket option is
+ // enabled on this end.
+ Passcred() bool
+
+ // Release releases any resources held by the BoundEndpoint. It must be
+ // called before dropping all references to a BoundEndpoint returned by a
+ // function.
+ Release()
+}
+
+// message represents a message passed over a Unix domain socket.
+//
+// +stateify savable
+type message struct {
+ messageEntry
+
+ // Data is the Message payload.
+ Data buffer.View
+
+ // Control is auxiliary control message data that goes along with the
+ // data.
+ Control ControlMessages
+
+ // Address is the bound address of the endpoint that sent the message.
+ //
+ // If the endpoint that sent the message is not bound, the Address is
+ // the empty string.
+ Address tcpip.FullAddress
+}
+
+// Length returns number of bytes stored in the message.
+func (m *message) Length() int64 {
+ return int64(len(m.Data))
+}
+
+// Release releases any resources held by the message.
+func (m *message) Release() {
+ m.Control.Release()
+}
+
+// Peek returns a copy of the message.
+func (m *message) Peek() *message {
+ return &message{Data: m.Data, Control: m.Control.Clone(), Address: m.Address}
+}
+
+// Truncate reduces the length of the message payload to n bytes.
+//
+// Preconditions: n <= m.Length().
+func (m *message) Truncate(n int64) {
+ m.Data.CapLength(int(n))
+}
+
+// A Receiver can be used to receive Messages.
+type Receiver interface {
+ // Recv receives a single message. This method does not block.
+ //
+ // See Endpoint.RecvMsg for documentation on shared arguments.
+ //
+ // notify indicates if RecvNotify should be called.
+ Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error)
+
+ // RecvNotify notifies the Receiver of a successful Recv. This must not be
+ // called while holding any endpoint locks.
+ RecvNotify()
+
+ // CloseRecv prevents the receiving of additional Messages.
+ //
+ // After CloseRecv is called, CloseNotify must also be called.
+ CloseRecv()
+
+ // CloseNotify notifies the Receiver of recv being closed. This must not be
+ // called while holding any endpoint locks.
+ CloseNotify()
+
+ // Readable returns if messages should be attempted to be received. This
+ // includes when read has been shutdown.
+ Readable() bool
+
+ // RecvQueuedSize returns the total amount of data currently receivable.
+ // RecvQueuedSize should return -1 if the operation isn't supported.
+ RecvQueuedSize() int64
+
+ // RecvMaxQueueSize returns maximum value for RecvQueuedSize.
+ // RecvMaxQueueSize should return -1 if the operation isn't supported.
+ RecvMaxQueueSize() int64
+
+ // Release releases any resources owned by the Receiver. It should be
+ // called before droping all references to a Receiver.
+ Release()
+}
+
+// queueReceiver implements Receiver for datagram sockets.
+//
+// +stateify savable
+type queueReceiver struct {
+ readQueue *queue
+}
+
+// Recv implements Receiver.Recv.
+func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ var m *message
+ var notify bool
+ var err *syserr.Error
+ if peek {
+ m, err = q.readQueue.Peek()
+ } else {
+ m, notify, err = q.readQueue.Dequeue()
+ }
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ src := []byte(m.Data)
+ var copied int64
+ for i := 0; i < len(data) && len(src) > 0; i++ {
+ n := copy(data[i], src)
+ copied += int64(n)
+ src = src[n:]
+ }
+ return copied, int64(len(m.Data)), m.Control, false, m.Address, notify, nil
+}
+
+// RecvNotify implements Receiver.RecvNotify.
+func (q *queueReceiver) RecvNotify() {
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseNotify implements Receiver.CloseNotify.
+func (q *queueReceiver) CloseNotify() {
+ q.readQueue.ReaderQueue.Notify(waiter.EventIn)
+ q.readQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseRecv implements Receiver.CloseRecv.
+func (q *queueReceiver) CloseRecv() {
+ q.readQueue.Close()
+}
+
+// Readable implements Receiver.Readable.
+func (q *queueReceiver) Readable() bool {
+ return q.readQueue.IsReadable()
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *queueReceiver) RecvQueuedSize() int64 {
+ return q.readQueue.QueuedSize()
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *queueReceiver) RecvMaxQueueSize() int64 {
+ return q.readQueue.MaxQueueSize()
+}
+
+// Release implements Receiver.Release.
+func (q *queueReceiver) Release() {
+ q.readQueue.DecRef()
+}
+
+// streamQueueReceiver implements Receiver for stream sockets.
+//
+// +stateify savable
+type streamQueueReceiver struct {
+ queueReceiver
+
+ mu sync.Mutex `state:"nosave"`
+ buffer []byte
+ control ControlMessages
+ addr tcpip.FullAddress
+}
+
+func vecCopy(data [][]byte, buf []byte) (int64, [][]byte, []byte) {
+ var copied int64
+ for len(data) > 0 && len(buf) > 0 {
+ n := copy(data[0], buf)
+ copied += int64(n)
+ buf = buf[n:]
+ data[0] = data[0][n:]
+ if len(data[0]) == 0 {
+ data = data[1:]
+ }
+ }
+ return copied, data, buf
+}
+
+// Readable implements Receiver.Readable.
+func (q *streamQueueReceiver) Readable() bool {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ r := q.readQueue.IsReadable()
+ q.mu.Unlock()
+ // We're readable if we have data in our buffer or if the queue receiver is
+ // readable.
+ return bl > 0 || r
+}
+
+// RecvQueuedSize implements Receiver.RecvQueuedSize.
+func (q *streamQueueReceiver) RecvQueuedSize() int64 {
+ q.mu.Lock()
+ bl := len(q.buffer)
+ qs := q.readQueue.QueuedSize()
+ q.mu.Unlock()
+ return int64(bl) + qs
+}
+
+// RecvMaxQueueSize implements Receiver.RecvMaxQueueSize.
+func (q *streamQueueReceiver) RecvMaxQueueSize() int64 {
+ // The RecvMaxQueueSize() is the readQueue's MaxQueueSize() plus the largest
+ // message we can buffer which is also the largest message we can receive.
+ return 2 * q.readQueue.MaxQueueSize()
+}
+
+// Recv implements Receiver.Recv.
+func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ var notify bool
+
+ // If we have no data in the endpoint, we need to get some.
+ if len(q.buffer) == 0 {
+ // Load the next message into a buffer, even if we are peeking. Peeking
+ // won't consume the message, so it will be still available to be read
+ // the next time Recv() is called.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, tcpip.FullAddress{}, false, err
+ }
+ notify = n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+ }
+
+ var copied int64
+ if peek {
+ // Don't consume control message if we are peeking.
+ c := q.control.Clone()
+
+ // Don't consume data since we are peeking.
+ copied, data, _ = vecCopy(data, q.buffer)
+
+ return copied, copied, c, false, q.addr, notify, nil
+ }
+
+ // Consume data and control message since we are not peeking.
+ copied, data, q.buffer = vecCopy(data, q.buffer)
+
+ // Save the original state of q.control.
+ c := q.control
+
+ // Remove rights from q.control and leave behind just the creds.
+ q.control.Rights = nil
+ if !wantCreds {
+ c.Credentials = nil
+ }
+
+ var cmTruncated bool
+ if c.Rights != nil && numRights == 0 {
+ c.Rights.Release()
+ c.Rights = nil
+ cmTruncated = true
+ }
+
+ haveRights := c.Rights != nil
+
+ // If we have more capacity for data and haven't received any usable
+ // rights.
+ //
+ // Linux never coalesces rights control messages.
+ for !haveRights && len(data) > 0 {
+ // Get a message from the readQueue.
+ m, n, err := q.readQueue.Dequeue()
+ if err != nil {
+ // We already got some data, so ignore this error. This will
+ // manifest as a short read to the user, which is what Linux
+ // does.
+ break
+ }
+ notify = notify || n
+ q.buffer = []byte(m.Data)
+ q.control = m.Control
+ q.addr = m.Address
+
+ if wantCreds {
+ if (q.control.Credentials == nil) != (c.Credentials == nil) {
+ // One message has credentials, the other does not.
+ break
+ }
+
+ if q.control.Credentials != nil && c.Credentials != nil && !q.control.Credentials.Equals(c.Credentials) {
+ // Both messages have credentials, but they don't match.
+ break
+ }
+ }
+
+ if numRights != 0 && c.Rights != nil && q.control.Rights != nil {
+ // Both messages have rights.
+ break
+ }
+
+ var cpd int64
+ cpd, data, q.buffer = vecCopy(data, q.buffer)
+ copied += cpd
+
+ if cpd == 0 {
+ // data was actually full.
+ break
+ }
+
+ if q.control.Rights != nil {
+ // Consume rights.
+ if numRights == 0 {
+ cmTruncated = true
+ q.control.Rights.Release()
+ } else {
+ c.Rights = q.control.Rights
+ haveRights = true
+ }
+ q.control.Rights = nil
+ }
+ }
+ return copied, copied, c, cmTruncated, q.addr, notify, nil
+}
+
+// A ConnectedEndpoint is an Endpoint that can be used to send Messages.
+type ConnectedEndpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Send sends a single message. This method does not block.
+ //
+ // notify indicates if SendNotify should be called.
+ //
+ // syserr.ErrWouldBlock can be returned along with a partial write if
+ // the caller should block to send the rest of the data.
+ Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error)
+
+ // SendNotify notifies the ConnectedEndpoint of a successful Send. This
+ // must not be called while holding any endpoint locks.
+ SendNotify()
+
+ // CloseSend prevents the sending of additional Messages.
+ //
+ // After CloseSend is call, CloseNotify must also be called.
+ CloseSend()
+
+ // CloseNotify notifies the ConnectedEndpoint of send being closed. This
+ // must not be called while holding any endpoint locks.
+ CloseNotify()
+
+ // Writable returns if messages should be attempted to be sent. This
+ // includes when write has been shutdown.
+ Writable() bool
+
+ // EventUpdate lets the ConnectedEndpoint know that event registrations
+ // have changed.
+ EventUpdate()
+
+ // SendQueuedSize returns the total amount of data currently queued for
+ // sending. SendQueuedSize should return -1 if the operation isn't
+ // supported.
+ SendQueuedSize() int64
+
+ // SendMaxQueueSize returns maximum value for SendQueuedSize.
+ // SendMaxQueueSize should return -1 if the operation isn't supported.
+ SendMaxQueueSize() int64
+
+ // Release releases any resources owned by the ConnectedEndpoint. It should
+ // be called before droping all references to a ConnectedEndpoint.
+ Release()
+
+ // CloseUnread sets the fact that this end is closed with unread data to
+ // the peer socket.
+ CloseUnread()
+}
+
+// +stateify savable
+type connectedEndpoint struct {
+ // endpoint represents the subset of the Endpoint functionality needed by
+ // the connectedEndpoint. It is implemented by both connectionedEndpoint
+ // and connectionlessEndpoint and allows the use of types which don't
+ // fully implement Endpoint.
+ endpoint interface {
+ // Passcred implements Endpoint.Passcred.
+ Passcred() bool
+
+ // GetLocalAddress implements Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Type implements Endpoint.Type.
+ Type() linux.SockType
+ }
+
+ writeQueue *queue
+}
+
+// Passcred implements ConnectedEndpoint.Passcred.
+func (e *connectedEndpoint) Passcred() bool {
+ return e.endpoint.Passcred()
+}
+
+// GetLocalAddress implements ConnectedEndpoint.GetLocalAddress.
+func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ return e.endpoint.GetLocalAddress()
+}
+
+// Send implements ConnectedEndpoint.Send.
+func (e *connectedEndpoint) Send(data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) {
+ discardEmpty := false
+ truncate := false
+ if e.endpoint.Type() == linux.SOCK_STREAM {
+ // Discard empty stream packets. Since stream sockets don't
+ // preserve message boundaries, sending zero bytes is a no-op.
+ // In Linux, the receiver actually uses a zero-length receive
+ // as an indication that the stream was closed.
+ discardEmpty = true
+
+ // Since stream sockets don't preserve message boundaries, we
+ // can write only as much of the message as fits in the queue.
+ truncate = true
+ }
+
+ return e.writeQueue.Enqueue(data, c, from, discardEmpty, truncate)
+}
+
+// SendNotify implements ConnectedEndpoint.SendNotify.
+func (e *connectedEndpoint) SendNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+}
+
+// CloseNotify implements ConnectedEndpoint.CloseNotify.
+func (e *connectedEndpoint) CloseNotify() {
+ e.writeQueue.ReaderQueue.Notify(waiter.EventIn)
+ e.writeQueue.WriterQueue.Notify(waiter.EventOut)
+}
+
+// CloseSend implements ConnectedEndpoint.CloseSend.
+func (e *connectedEndpoint) CloseSend() {
+ e.writeQueue.Close()
+}
+
+// Writable implements ConnectedEndpoint.Writable.
+func (e *connectedEndpoint) Writable() bool {
+ return e.writeQueue.IsWritable()
+}
+
+// EventUpdate implements ConnectedEndpoint.EventUpdate.
+func (*connectedEndpoint) EventUpdate() {}
+
+// SendQueuedSize implements ConnectedEndpoint.SendQueuedSize.
+func (e *connectedEndpoint) SendQueuedSize() int64 {
+ return e.writeQueue.QueuedSize()
+}
+
+// SendMaxQueueSize implements ConnectedEndpoint.SendMaxQueueSize.
+func (e *connectedEndpoint) SendMaxQueueSize() int64 {
+ return e.writeQueue.MaxQueueSize()
+}
+
+// Release implements ConnectedEndpoint.Release.
+func (e *connectedEndpoint) Release() {
+ e.writeQueue.DecRef()
+}
+
+// CloseUnread implements ConnectedEndpoint.CloseUnread.
+func (e *connectedEndpoint) CloseUnread() {
+ e.writeQueue.CloseUnread()
+}
+
+// baseEndpoint is an embeddable unix endpoint base used in both the connected and connectionless
+// unix domain socket Endpoint implementations.
+//
+// Not to be used on its own.
+//
+// +stateify savable
+type baseEndpoint struct {
+ *waiter.Queue
+
+ // passcred specifies whether SCM_CREDENTIALS socket control messages are
+ // enabled on this endpoint. Must be accessed atomically.
+ passcred int32
+
+ // Mutex protects the below fields.
+ sync.Mutex `state:"nosave"`
+
+ // receiver allows Messages to be received.
+ receiver Receiver
+
+ // connected allows messages to be sent and state information about the
+ // connected endpoint to be read.
+ connected ConnectedEndpoint
+
+ // path is not empty if the endpoint has been bound,
+ // or may be used if the endpoint is connected.
+ path string
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (e *baseEndpoint) EventRegister(we *waiter.Entry, mask waiter.EventMask) {
+ e.Queue.EventRegister(we, mask)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (e *baseEndpoint) EventUnregister(we *waiter.Entry) {
+ e.Queue.EventUnregister(we)
+ e.Lock()
+ if e.connected != nil {
+ e.connected.EventUpdate()
+ }
+ e.Unlock()
+}
+
+// Passcred implements Credentialer.Passcred.
+func (e *baseEndpoint) Passcred() bool {
+ return atomic.LoadInt32(&e.passcred) != 0
+}
+
+// ConnectedPasscred implements Credentialer.ConnectedPasscred.
+func (e *baseEndpoint) ConnectedPasscred() bool {
+ e.Lock()
+ defer e.Unlock()
+ return e.connected != nil && e.connected.Passcred()
+}
+
+func (e *baseEndpoint) setPasscred(pc bool) {
+ if pc {
+ atomic.StoreInt32(&e.passcred, 1)
+ } else {
+ atomic.StoreInt32(&e.passcred, 0)
+ }
+}
+
+// Connected implements ConnectingEndpoint.Connected.
+func (e *baseEndpoint) Connected() bool {
+ return e.receiver != nil && e.connected != nil
+}
+
+// RecvMsg reads data and a control message from the endpoint.
+func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool, addr *tcpip.FullAddress) (int64, int64, ControlMessages, bool, *syserr.Error) {
+ e.Lock()
+
+ if e.receiver == nil {
+ e.Unlock()
+ return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected
+ }
+
+ recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek)
+ e.Unlock()
+ if err != nil {
+ return 0, 0, ControlMessages{}, false, err
+ }
+
+ if notify {
+ e.receiver.RecvNotify()
+ }
+
+ if addr != nil {
+ *addr = a
+ }
+ return recvLen, msgLen, cms, cmt, nil
+}
+
+// SendMsg writes data and a control message to the endpoint's peer.
+// This method does not block if the data cannot be written.
+func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMessages, to BoundEndpoint) (int64, *syserr.Error) {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return 0, syserr.ErrNotConnected
+ }
+ if to != nil {
+ e.Unlock()
+ return 0, syserr.ErrAlreadyConnected
+ }
+
+ n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)})
+ e.Unlock()
+
+ if notify {
+ e.connected.SendNotify()
+ }
+
+ return n, err
+}
+
+// SetSockOpt sets a socket option. Currently not supported.
+func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+ return nil
+}
+
+func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ case tcpip.PasscredOption:
+ e.setPasscred(v)
+ case tcpip.ReuseAddressOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
+ return nil
+}
+
+func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
+ return nil
+}
+
+func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.PasscredOption:
+ return e.Passcred(), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return false, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
+ switch opt {
+ case tcpip.ReceiveQueueSizeOption:
+ v := 0
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v = int(e.receiver.RecvQueuedSize())
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return v, nil
+
+ case tcpip.SendQueueSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v := e.connected.SendQueuedSize()
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ case tcpip.SendBufferSizeOption:
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v := e.connected.SendMaxQueueSize()
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ case tcpip.ReceiveBufferSizeOption:
+ e.Lock()
+ if e.receiver == nil {
+ e.Unlock()
+ return -1, tcpip.ErrNotConnected
+ }
+ v := e.receiver.RecvMaxQueueSize()
+ e.Unlock()
+ if v < 0 {
+ return -1, tcpip.ErrQueueSizeNotSupported
+ }
+ return int(v), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return -1, tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
+func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
+ switch opt.(type) {
+ case tcpip.ErrorOption:
+ return nil
+
+ default:
+ log.Warningf("Unsupported socket option: %T", opt)
+ return tcpip.ErrUnknownProtocolOption
+ }
+}
+
+// Shutdown closes the read and/or write end of the endpoint connection to its
+// peer.
+func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
+ e.Lock()
+ if !e.Connected() {
+ e.Unlock()
+ return syserr.ErrNotConnected
+ }
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseRecv()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseSend()
+ }
+
+ e.Unlock()
+
+ if flags&tcpip.ShutdownRead != 0 {
+ e.receiver.CloseNotify()
+ }
+
+ if flags&tcpip.ShutdownWrite != 0 {
+ e.connected.CloseNotify()
+ }
+
+ return nil
+}
+
+// GetLocalAddress returns the bound path.
+func (e *baseEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ defer e.Unlock()
+ return tcpip.FullAddress{Addr: tcpip.Address(e.path)}, nil
+}
+
+// GetRemoteAddress returns the local address of the connected endpoint (if
+// available).
+func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) {
+ e.Lock()
+ c := e.connected
+ e.Unlock()
+ if c != nil {
+ return c.GetLocalAddress()
+ }
+ return tcpip.FullAddress{}, tcpip.ErrNotConnected
+}
+
+// Release implements BoundEndpoint.Release.
+func (*baseEndpoint) Release() {
+ // Binding a baseEndpoint doesn't take a reference.
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
new file mode 100644
index 000000000..4bb2b6ff4
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix.go
@@ -0,0 +1,772 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unix provides an implementation of the socket.Socket interface for
+// the AF_UNIX protocol family.
+package unix
+
+import (
+ "fmt"
+ "strings"
+ "syscall"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/refs"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketOperations is a Unix socket. It is similar to a netstack socket,
+// except it is backed by a transport.Endpoint instead of a tcpip.Endpoint.
+//
+// +stateify savable
+type SocketOperations struct {
+ fsutil.FilePipeSeek `state:"nosave"`
+ fsutil.FileNotDirReaddir `state:"nosave"`
+ fsutil.FileNoFsync `state:"nosave"`
+ fsutil.FileNoMMap `state:"nosave"`
+ fsutil.FileNoSplice `state:"nosave"`
+ fsutil.FileNoopFlush `state:"nosave"`
+ fsutil.FileUseInodeUnstableAttr `state:"nosave"`
+
+ socketOpsCommon
+}
+
+// New creates a new unix socket.
+func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File {
+ dirent := socket.NewDirent(ctx, unixSocketDevice)
+ defer dirent.DecRef()
+ return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true})
+}
+
+// NewWithDirent creates a new unix socket using an existing dirent.
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, stype linux.SockType, flags fs.FileFlags) *fs.File {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ s := SocketOperations{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+ s.EnableLeakCheck("unix.SocketOperations")
+
+ return fs.NewFile(ctx, d, flags, &s)
+}
+
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ stype linux.SockType
+}
+
+// DecRef implements RefCounter.DecRef.
+func (s *socketOpsCommon) DecRef() {
+ s.DecRefWithDestructor(func() {
+ s.ep.Close()
+ })
+}
+
+// Release implemements fs.FileOperations.Release.
+func (s *socketOpsCommon) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+func (s *socketOpsCommon) isPacket() bool {
+ switch s.stype {
+ case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
+ return true
+ case linux.SOCK_STREAM:
+ return false
+ default:
+ // We shouldn't have allowed any other socket types during creation.
+ panic(fmt.Sprintf("Invalid socket type %d", s.stype))
+ }
+}
+
+// Endpoint extracts the transport.Endpoint.
+func (s *socketOpsCommon) Endpoint() transport.Endpoint {
+ return s.ep
+}
+
+// extractPath extracts and validates the address.
+func extractPath(sockaddr []byte) (string, *syserr.Error) {
+ addr, family, err := netstack.AddressAndFamily(sockaddr)
+ if err != nil {
+ if err == syserr.ErrAddressFamilyNotSupported {
+ err = syserr.ErrInvalidArgument
+ }
+ return "", err
+ }
+ if family != linux.AF_UNIX {
+ return "", syserr.ErrInvalidArgument
+ }
+
+ // The address is trimmed by GetAddress.
+ p := string(addr.Addr)
+ if p == "" {
+ // Not allowed.
+ return "", syserr.ErrInvalidArgument
+ }
+ if p[len(p)-1] == '/' {
+ // Weird, they tried to bind '/a/b/c/'?
+ return "", syserr.ErrIsDir
+ }
+
+ return p, nil
+}
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.ep.GetRemoteAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+ addr, err := s.ep.GetLocalAddress()
+ if err != nil {
+ return nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ a, l := netstack.ConvertAddress(linux.AF_UNIX, addr)
+ return a, l, nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, io, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return s.ep.Listen(backlog)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns := New(t, ep, s.stype)
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocket(ns)
+
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ // The parent and name.
+ var d *fs.Dirent
+ var name string
+
+ cwd := t.FSContext().WorkingDirectory()
+ defer cwd.DecRef()
+
+ // Is there no slash at all?
+ if !strings.Contains(p, "/") {
+ d = cwd
+ name = p
+ } else {
+ root := t.FSContext().RootDirectory()
+ defer root.DecRef()
+ // Find the last path component, we know that something follows
+ // that final slash, otherwise extractPath() would have failed.
+ lastSlash := strings.LastIndex(p, "/")
+ subPath := p[:lastSlash]
+ if subPath == "" {
+ // Fix up subpath in case file is in root.
+ subPath = "/"
+ }
+ var err error
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, err = t.MountNamespace().FindInode(t, root, cwd, subPath, &remainingTraversals)
+ if err != nil {
+ // No path available.
+ return syserr.ErrNoSuchFile
+ }
+ defer d.DecRef()
+ name = p[lastSlash+1:]
+ }
+
+ // Create the socket.
+ //
+ // Note that the file permissions here are not set correctly (see
+ // gvisor.dev/issue/2324). There is no convenient way to get permissions
+ // on the socket referred to by s, so we will leave this discrepancy
+ // unresolved until VFS2 replaces this code.
+ childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
+ if err != nil {
+ return syserr.ErrPortInUse
+ }
+ childDir.DecRef()
+ }
+
+ return nil
+ })
+}
+
+// extractEndpoint retrieves the transport.BoundEndpoint associated with a Unix
+// socket path. The Release must be called on the transport.BoundEndpoint when
+// the caller is done with it.
+func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, *syserr.Error) {
+ path, err := extractPath(sockaddr)
+ if err != nil {
+ return nil, err
+ }
+
+ // Is it abstract?
+ if path[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ ep := t.AbstractSockets().BoundEndpoint(path[1:])
+ if ep == nil {
+ // No socket found.
+ return nil, syserr.ErrConnectionRefused
+ }
+
+ return ep, nil
+ }
+
+ if kernel.VFS2Enabled {
+ p := fspath.Parse(path)
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ relPath := !p.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: p,
+ FollowFinalSymlink: true,
+ }
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path})
+ root.DecRef()
+ if relPath {
+ start.DecRef()
+ }
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+ return ep, nil
+ }
+
+ // Find the node in the filesystem.
+ root := t.FSContext().RootDirectory()
+ cwd := t.FSContext().WorkingDirectory()
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals)
+ cwd.DecRef()
+ root.DecRef()
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+
+ // Extract the endpoint if one is there.
+ ep := d.Inode.BoundEndpoint(path)
+ d.DecRef()
+ if ep == nil {
+ // No socket!
+ return nil, syserr.ErrConnectionRefused
+ }
+ return ep, nil
+}
+
+// Connect implements the linux syscall connect(2) for unix sockets.
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ ep, err := extractEndpoint(t, sockaddr)
+ if err != nil {
+ return err
+ }
+ defer ep.Release()
+
+ // Connect the server endpoint.
+ err = s.ep.Connect(t, ep)
+
+ if err == syserr.ErrWrongProtocolForSocket {
+ // Linux for abstract sockets returns ErrConnectionRefused
+ // instead of ErrWrongProtocolForSocket.
+ path, _ := extractPath(sockaddr)
+ if len(path) > 0 && path[0] == 0 {
+ err = syserr.ErrConnectionRefused
+ }
+ }
+
+ return err
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+ w := EndpointWriter{
+ Ctx: t,
+ Endpoint: s.ep,
+ Control: controlMessages.Unix,
+ To: nil,
+ }
+ if len(to) > 0 {
+ switch s.stype {
+ case linux.SOCK_SEQPACKET:
+ to = nil
+ case linux.SOCK_STREAM:
+ if s.State() == linux.SS_CONNECTED {
+ return 0, syserr.ErrAlreadyConnected
+ }
+ return 0, syserr.ErrNotSupported
+ default:
+ ep, err := extractEndpoint(t, to)
+ if err != nil {
+ return 0, err
+ }
+ defer ep.Release()
+ w.To = ep
+
+ if ep.Passcred() && w.Control.Credentials == nil {
+ w.Control.Credentials = control.MakeCreds(t)
+ }
+ }
+ }
+
+ n, err := src.CopyInTo(t, &w)
+ if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ return int(n), syserr.FromError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ total := n
+ for {
+ // Shorten src to reflect bytes previously written.
+ src = src.DropFirst64(n)
+
+ n, err = src.CopyInTo(t, &w)
+ total += n
+ if err != syserror.ErrWouldBlock {
+ break
+ }
+
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if err == syserror.ETIMEDOUT {
+ err = syserror.ErrWouldBlock
+ }
+ break
+ }
+ }
+
+ return int(total), syserr.FromError(err)
+}
+
+// Passcred implements transport.Credentialer.Passcred.
+func (s *socketOpsCommon) Passcred() bool {
+ return s.ep.Passcred()
+}
+
+// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
+func (s *socketOpsCommon) ConnectedPasscred() bool {
+ return s.ep.ConnectedPasscred()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.ep.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ f, err := netstack.ConvertShutdown(how)
+ if err != nil {
+ return err
+ }
+
+ // Issue shutdown request.
+ return s.ep.Shutdown(f)
+}
+
+// Read implements fs.FileOperations.Read.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// a transport.Endpoint.
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+ peek := flags&linux.MSG_PEEK != 0
+ dontWait := flags&linux.MSG_DONTWAIT != 0
+ waitAll := flags&linux.MSG_WAITALL != 0
+ isPacket := s.isPacket()
+
+ // Calculate the number of FDs for which we have space and if we are
+ // requesting credentials.
+ var wantCreds bool
+ rightsLen := int(controlDataLen) - syscall.SizeofCmsghdr
+ if s.Passcred() {
+ // Credentials take priority if they are enabled and there is space.
+ wantCreds = rightsLen > 0
+ if !wantCreds {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+ credLen := syscall.CmsgSpace(syscall.SizeofUcred)
+ rightsLen -= credLen
+ }
+ // FDs are 32 bit (4 byte) ints.
+ numRights := rightsLen / 4
+ if numRights < 0 {
+ numRights = 0
+ }
+
+ r := EndpointReader{
+ Ctx: t,
+ Endpoint: s.ep,
+ Creds: wantCreds,
+ NumRights: numRights,
+ Peek: peek,
+ }
+ if senderRequested {
+ r.From = &tcpip.FullAddress{}
+ }
+
+ doRead := func() (int64, error) {
+ return dst.CopyOutFrom(t, &r)
+ }
+
+ // If MSG_TRUNC is set with a zero byte destination then we still need
+ // to read the message and discard it, or in the case where MSG_PEEK is
+ // set, leave it be. In both cases the full message length must be
+ // returned.
+ if trunc && dst.Addrs.NumBytes() == 0 {
+ doRead = func() (int64, error) {
+ err := r.Truncate()
+ // Always return zero for bytes read since the destination size is
+ // zero.
+ return 0, err
+ }
+
+ }
+
+ var total int64
+ if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait {
+ var from linux.SockAddr
+ var fromLen uint32
+ if r.From != nil && len([]byte(r.From.Addr)) != 0 {
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if err != nil || dontWait || !waitAll || isPacket || n >= dst.NumBytes() {
+ if isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+
+ if trunc {
+ n = int64(r.MsgSize)
+ }
+
+ return int(n), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ total += n
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ if n, err := doRead(); err != syserror.ErrWouldBlock {
+ var from linux.SockAddr
+ var fromLen uint32
+ if r.From != nil {
+ from, fromLen = netstack.ConvertAddress(linux.AF_UNIX, *r.From)
+ }
+
+ if r.ControlTrunc {
+ msgFlags |= linux.MSG_CTRUNC
+ }
+
+ if trunc {
+ // n and r.MsgSize are the same for streams.
+ total += int64(r.MsgSize)
+ } else {
+ total += n
+ }
+
+ streamPeerClosed := s.stype == linux.SOCK_STREAM && n == 0 && err == nil
+ if err != nil || !waitAll || isPacket || n >= dst.NumBytes() || streamPeerClosed {
+ if total > 0 {
+ err = nil
+ }
+ if isPacket && n < int64(r.MsgSize) {
+ msgFlags |= linux.MSG_TRUNC
+ }
+ return int(total), msgFlags, from, fromLen, socket.ControlMessages{Unix: r.Control}, syserr.FromError(err)
+ }
+
+ // Don't overwrite any data we received.
+ dst = dst.DropFirst64(n)
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ if total > 0 {
+ err = nil
+ }
+ if err == syserror.ETIMEDOUT {
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return int(total), msgFlags, nil, 0, socket.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+}
+
+// State implements socket.Socket.State.
+func (s *socketOpsCommon) State() uint32 {
+ return s.ep.State()
+}
+
+// Type implements socket.Socket.Type.
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
+ // Unix domain sockets always have a protocol of 0.
+ return linux.AF_UNIX, s.stype, 0
+}
+
+// provider is a unix domain socket provider.
+type provider struct{}
+
+// Socket returns a new unix domain socket.
+func (*provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return New(t, ep, stype), nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1 := New(t, ep1, stype)
+ s2 := New(t, ep2, stype)
+
+ return s1, s2, nil
+}
+
+func init() {
+ socket.RegisterProvider(linux.AF_UNIX, &provider{})
+ socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{})
+}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
new file mode 100644
index 000000000..ff2149250
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -0,0 +1,371 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 implements socket.SocketVFS2 (and by extension,
+// vfs.FileDescriptionImpl) for Unix sockets.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+ vfs.LockFD
+
+ socketOpsCommon
+}
+
+var _ = socket.SocketVFS2(&SocketVFS2{})
+
+// NewSockfsFile creates a new socket file in the global sockfs mount and
+// returns a corresponding file description.
+func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ mnt := t.Kernel().SocketMount()
+ d := sockfs.NewDentry(t.Credentials(), mnt)
+
+ fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{})
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return fd, nil
+}
+
+// NewFileDescription creates and returns a socket file description
+// corresponding to the given mount and dentry.
+func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ sock := &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+ sock.LockFD.Init(locks)
+ vfsfd := &sock.vfsfd
+ if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{
+ DenyPRead: true,
+ DenyPWrite: true,
+ UseDentryMetadata: true,
+ }); err != nil {
+ return nil, err
+ }
+ return vfsfd, nil
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
+ defer s.socketOpsCommon.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns, err := NewSockfsFile(t, ep, s.stype)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocketVFS2(ns)
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ path := fspath.Parse(p)
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ relPath := !path.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ }
+ stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE})
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // File permissions correspond to net/unix/af_unix.c:unix_bind.
+ Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()),
+ Endpoint: bep,
+ })
+ if err == syserror.EEXIST {
+ return syserr.ErrAddressInUse
+ }
+ return syserr.FromError(err)
+ }
+
+ return nil
+ })
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, uio, args)
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX.
+func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error {
+ return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block)
+}
+
+// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX.
+func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error {
+ return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence)
+}
+
+// providerVFS2 is a unix domain socket provider for VFS2.
+type providerVFS2 struct{}
+
+func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ f, err := NewSockfsFile(t, ep, stype)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+ return f, nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1, err := NewSockfsFile(t, ep1, stype)
+ if err != nil {
+ ep1.Close()
+ ep2.Close()
+ return nil, nil, err
+ }
+ s2, err := NewSockfsFile(t, ep2, stype)
+ if err != nil {
+ s1.DecRef()
+ ep2.Close()
+ return nil, nil, err
+ }
+
+ return s1, s2, nil
+}