summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD37
-rw-r--r--pkg/sentry/socket/control/BUILD39
-rw-r--r--pkg/sentry/socket/control/control.go370
-rw-r--r--pkg/sentry/socket/epsocket/BUILD61
-rw-r--r--pkg/sentry/socket/epsocket/device.go20
-rw-r--r--pkg/sentry/socket/epsocket/epsocket.go1230
-rw-r--r--pkg/sentry/socket/epsocket/provider.go113
-rw-r--r--pkg/sentry/socket/epsocket/save_restore.go27
-rw-r--r--pkg/sentry/socket/epsocket/stack.go132
-rw-r--r--pkg/sentry/socket/hostinet/BUILD53
-rw-r--r--pkg/sentry/socket/hostinet/device.go19
-rw-r--r--pkg/sentry/socket/hostinet/hostinet.go17
-rw-r--r--pkg/sentry/socket/hostinet/save_restore.go20
-rw-r--r--pkg/sentry/socket/hostinet/socket.go562
-rw-r--r--pkg/sentry/socket/hostinet/socket_unsafe.go138
-rw-r--r--pkg/sentry/socket/hostinet/stack.go244
-rw-r--r--pkg/sentry/socket/netlink/BUILD47
-rw-r--r--pkg/sentry/socket/netlink/message.go159
-rw-r--r--pkg/sentry/socket/netlink/port/BUILD28
-rw-r--r--pkg/sentry/socket/netlink/port/port.go114
-rw-r--r--pkg/sentry/socket/netlink/port/port_test.go82
-rw-r--r--pkg/sentry/socket/netlink/provider.go104
-rw-r--r--pkg/sentry/socket/netlink/route/BUILD33
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go189
-rw-r--r--pkg/sentry/socket/netlink/socket.go517
-rw-r--r--pkg/sentry/socket/rpcinet/BUILD59
-rw-r--r--pkg/sentry/socket/rpcinet/conn/BUILD17
-rw-r--r--pkg/sentry/socket/rpcinet/conn/conn.go167
-rw-r--r--pkg/sentry/socket/rpcinet/device.go19
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/BUILD15
-rw-r--r--pkg/sentry/socket/rpcinet/notifier/notifier.go230
-rw-r--r--pkg/sentry/socket/rpcinet/rpcinet.go16
-rw-r--r--pkg/sentry/socket/rpcinet/socket.go567
-rw-r--r--pkg/sentry/socket/rpcinet/stack.go175
-rw-r--r--pkg/sentry/socket/rpcinet/stack_unsafe.go193
-rw-r--r--pkg/sentry/socket/rpcinet/syscall_rpc.proto351
-rw-r--r--pkg/sentry/socket/socket.go205
-rw-r--r--pkg/sentry/socket/unix/BUILD48
-rw-r--r--pkg/sentry/socket/unix/device.go20
-rw-r--r--pkg/sentry/socket/unix/io.go88
-rw-r--r--pkg/sentry/socket/unix/unix.go571
41 files changed, 7096 insertions, 0 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
new file mode 100644
index 000000000..87e32df37
--- /dev/null
+++ b/pkg/sentry/socket/BUILD
@@ -0,0 +1,37 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "socket_state",
+ srcs = [
+ "socket.go",
+ ],
+ out = "socket_state_autogen.go",
+ package = "socket",
+)
+
+go_library(
+ name = "socket",
+ srcs = [
+ "socket.go",
+ "socket_state_autogen.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/kdefs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/usermem",
+ "//pkg/state",
+ "//pkg/syserr",
+ "//pkg/tcpip/transport/unix",
+ ],
+)
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
new file mode 100644
index 000000000..25de2f655
--- /dev/null
+++ b/pkg/sentry/socket/control/BUILD
@@ -0,0 +1,39 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "control_state",
+ srcs = [
+ "control.go",
+ ],
+ out = "control_state.go",
+ imports = [
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs",
+ ],
+ package = "control",
+)
+
+go_library(
+ name = "control",
+ srcs = [
+ "control.go",
+ "control_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/kernel/kdefs",
+ "//pkg/sentry/usermem",
+ "//pkg/state",
+ "//pkg/syserror",
+ "//pkg/tcpip/transport/unix",
+ ],
+)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
new file mode 100644
index 000000000..cb34cbc85
--- /dev/null
+++ b/pkg/sentry/socket/control/control.go
@@ -0,0 +1,370 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package control provides internal representations of socket control
+// messages.
+package control
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+)
+
+// maxInt is the largest value representable by the platform's int type: all
+// bits of a uint set, shifted right by one to clear the sign bit.
+const maxInt = int(^uint(0) >> 1)
+
+// SCMCredentials represents a SCM_CREDENTIALS socket control message.
+type SCMCredentials interface {
+ unix.CredentialsControlMessage
+
+ // Credentials returns properly namespaced values for the sender's pid, uid
+ // and gid, as they should be seen by the task t.
+ Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID)
+}
+
+// SCMRights represents a SCM_RIGHTS socket control message.
+type SCMRights interface {
+ unix.RightsControlMessage
+
+ // Files returns up to max RightsFiles. The RightsFiles implementation in
+ // this package removes the returned files from the message.
+ Files(ctx context.Context, max int) RightsFiles
+}
+
+// RightsFiles represents a SCM_RIGHTS socket control message. A reference is
+// maintained for each fs.File and is released either when an FD is created or
+// when the Release method is called.
+type RightsFiles []*fs.File
+
+// NewSCMRights creates a new SCM_RIGHTS socket control message representation
+// using local sentry FDs.
+//
+// Each entry of fds is looked up in t's FD map. If any lookup yields no file,
+// the files gathered so far are released and EBADF is returned.
+func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) {
+ files := make(RightsFiles, 0, len(fds))
+ for _, fd := range fds {
+ // The second result of GetDescriptor is not needed here.
+ file, _ := t.FDMap().GetDescriptor(kdefs.FD(fd))
+ if file == nil {
+ files.Release()
+ return nil, syserror.EBADF
+ }
+ files = append(files, file)
+ }
+ return &files, nil
+}
+
+// Files implements SCMRights.Files.
+//
+// It returns the first min(max, len(*fs)) files and removes them from the
+// receiver, so successive calls hand out successive files. No reference
+// counts are changed here.
+//
+// NOTE(review): a negative max would make the slice expressions below panic;
+// the caller visible in this file (rightsFDs, via PackRights) only passes
+// positive values.
+func (fs *RightsFiles) Files(ctx context.Context, max int) RightsFiles {
+ n := max
+ if l := len(*fs); n > l {
+ n = l
+ }
+ rf := (*fs)[:n]
+ *fs = (*fs)[n:]
+ return rf
+}
+
+// Clone implements unix.RightsControlMessage.Clone.
+//
+// The returned message holds a copy of the file slice, and an additional
+// reference is taken on every file in it.
+func (fs *RightsFiles) Clone() unix.RightsControlMessage {
+ nfs := append(RightsFiles(nil), *fs...)
+ for _, nf := range nfs {
+ nf.IncRef()
+ }
+ return &nfs
+}
+
+// Release implements unix.RightsControlMessage.Release.
+//
+// It drops one reference on every file still held by the receiver and then
+// empties the receiver.
+func (fs *RightsFiles) Release() {
+ for _, f := range *fs {
+ f.DecRef()
+ }
+ *fs = nil
+}
+
+// rightsFDs gets up to the specified maximum number of FDs.
+//
+// It takes up to max files out of rights and installs each of them into t's
+// FD map with NewFDFrom, using cloexec as the close-on-exec flag. The
+// reference handed over by rights.Files is dropped after every insertion
+// attempt, successful or not.
+//
+// If an insertion fails, a warning is logged and installation stops. The
+// files that were taken out of rights but not yet processed are released, so
+// that their references are not leaked. The FDs installed up to that point
+// are returned.
+func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) []int32 {
+	files := rights.Files(t, max)
+	fds := make([]int32, 0, len(files))
+	for i := 0; i < max && len(files) > 0; i++ {
+		fd, err := t.FDMap().NewFDFrom(0, files[0], kernel.FDFlags{CloseOnExec: cloexec}, t.ThreadGroup().Limits())
+		files[0].DecRef()
+		files = files[1:]
+		if err != nil {
+			t.Warningf("Error inserting FD: %v", err)
+			// This is what Linux does.
+			break
+		}
+
+		fds = append(fds, int32(fd))
+	}
+
+	// rights.Files removed these files from the message, so nothing else
+	// will release them. This is a no-op unless the loop stopped early.
+	files.Release()
+	return fds
+}
+
+// PackRights packs as many FDs as will fit into the unused capacity of buf.
+//
+// The FD count is bounded by the unused capacity that remains after reserving
+// room for one control message header, at 4 bytes per FD. buf is returned
+// unchanged when not even one FD fits.
+func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte) []byte {
+ maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4
+ // Linux does not return any FDs if none fit.
+ if maxFDs <= 0 {
+ return buf
+ }
+ fds := rightsFDs(t, rights, cloexec, maxFDs)
+ // The message is padded to the architecture width of the task.
+ align := t.Arch().Width()
+ return putCmsg(buf, linux.SCM_RIGHTS, align, fds)
+}
+
+// scmCredentials represents an SCM_CREDENTIALS socket control message.
+//
+// t is the task whose ID is reported as the pid; kuid and kgid are the
+// kernel user and group IDs that are translated into the receiver's user
+// namespace by Credentials.
+type scmCredentials struct {
+ t *kernel.Task
+ kuid auth.KUID
+ kgid auth.KGID
+}
+
+// NewSCMCredentials creates a new SCM_CREDENTIALS socket control message
+// representation.
+//
+// The UID and GID in cred are validated against t's credentials with UseUID
+// and UseGID; any error from those calls is returned as is. EPERM is returned
+// when cred.PID is not the ID of t's thread group and t lacks CAP_SYS_ADMIN
+// in the user namespace of its PID namespace.
+func NewSCMCredentials(t *kernel.Task, cred linux.ControlMessageCredentials) (SCMCredentials, error) {
+ tcred := t.Credentials()
+ kuid, err := tcred.UseUID(auth.UID(cred.UID))
+ if err != nil {
+ return nil, err
+ }
+ kgid, err := tcred.UseGID(auth.GID(cred.GID))
+ if err != nil {
+ return nil, err
+ }
+ if kernel.ThreadID(cred.PID) != t.ThreadGroup().ID() && !t.HasCapabilityIn(linux.CAP_SYS_ADMIN, t.PIDNamespace().UserNamespace()) {
+ return nil, syserror.EPERM
+ }
+ // Note that cred.PID itself is not stored; the message records the task t.
+ return &scmCredentials{t, kuid, kgid}, nil
+}
+
+// Equals implements unix.CredentialsControlMessage.Equals.
+//
+// Two messages are equal only if oc is a non-nil *scmCredentials whose task,
+// kuid and kgid all match those of c.
+func (c *scmCredentials) Equals(oc unix.CredentialsControlMessage) bool {
+ if oc, _ := oc.(*scmCredentials); oc != nil && *c == *oc {
+ return true
+ }
+ return false
+}
+
+// putUint64 writes n in usermem.ByteOrder into the 8 bytes directly after
+// len(buf) and returns buf extended by those 8 bytes. The caller must ensure
+// that buf has at least 8 bytes of unused capacity; otherwise the slice
+// expression panics.
+func putUint64(buf []byte, n uint64) []byte {
+ usermem.ByteOrder.PutUint64(buf[len(buf):len(buf)+8], n)
+ return buf[:len(buf)+8]
+}
+
+// putUint32 writes n in usermem.ByteOrder into the 4 bytes directly after
+// len(buf) and returns buf extended by those 4 bytes. The caller must ensure
+// that buf has at least 4 bytes of unused capacity; otherwise the slice
+// expression panics.
+func putUint32(buf []byte, n uint32) []byte {
+ usermem.ByteOrder.PutUint32(buf[len(buf):len(buf)+4], n)
+ return buf[:len(buf)+4]
+}
+
+// putCmsg writes a control message header and as much data as will fit into
+// the unused capacity of a buffer.
+//
+// The header is written as a 64-bit length, followed by the 32-bit level
+// (always SOL_SOCKET) and the 32-bit msgType. Each element of data is then
+// written as a 32-bit value while room remains. Finally the buffer length is
+// rounded up to align if the capacity allows it.
+func putCmsg(buf []byte, msgType uint32, align uint, data []int32) []byte {
+ space := AlignDown(cap(buf)-len(buf), 4)
+
+ // We can't write to space that doesn't exist, so if we are going to align
+ // the available space, we must align down.
+ //
+ // align must be >= 4 and each data int32 is 4 bytes. The length of the
+ // header is already aligned, so if we align to the width of the data there
+ // are two cases:
+ // 1. The aligned length is less than the length of the header. The
+ // unaligned length was also less than the length of the header, so we
+ // can't write anything.
+ // 2. The aligned length is greater than or equal to the length of the
+ // header. We can write the header plus zero or more datas. We can't write
+ // a partial int32, so the length of the message will be
+ // min(aligned length, header + datas).
+ if space < linux.SizeOfControlMessageHeader {
+ return buf
+ }
+
+ length := 4*len(data) + linux.SizeOfControlMessageHeader
+ if length > space {
+ length = space
+ }
+ buf = putUint64(buf, uint64(length))
+ buf = putUint32(buf, linux.SOL_SOCKET)
+ buf = putUint32(buf, msgType)
+ for _, d := range data {
+ // Stop as soon as another 4-byte value would exceed the capacity.
+ if len(buf)+4 > cap(buf) {
+ break
+ }
+ buf = putUint32(buf, uint32(d))
+ }
+ return alignSlice(buf, align)
+}
+
+// Credentials implements SCMCredentials.Credentials.
+//
+// The pid is the ID of the sending task in t's PID namespace. The uid and gid
+// are the stored kernel IDs mapped into t's user namespace, with OrOverflow
+// applied to the mapping result.
+func (c *scmCredentials) Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) {
+ // "When a process's user and group IDs are passed over a UNIX domain
+ // socket to a process in a different user namespace (see the description
+ // of SCM_CREDENTIALS in unix(7)), they are translated into the
+ // corresponding values as per the receiving process's user and group ID
+ // mappings." - user_namespaces(7)
+ pid := t.PIDNamespace().IDOfTask(c.t)
+ uid := c.kuid.In(t.UserNamespace()).OrOverflow()
+ gid := c.kgid.In(t.UserNamespace()).OrOverflow()
+
+ return pid, uid, gid
+}
+
+// PackCredentials packs the credentials in the control message (or default
+// credentials if none) into a buffer.
+//
+// With a nil creds the packed values are pid 0 and the "nobody" uid and gid.
+// The message is written as three 32-bit values (pid, uid, gid) into the
+// unused capacity of buf, and the extended buffer is returned.
+func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte) []byte {
+ align := t.Arch().Width()
+
+ // Default credentials if none are available.
+ pid := kernel.ThreadID(0)
+ uid := auth.UID(auth.NobodyKUID)
+ gid := auth.GID(auth.NobodyKGID)
+
+ if creds != nil {
+ pid, uid, gid = creds.Credentials(t)
+ }
+ c := []int32{int32(pid), int32(uid), int32(gid)}
+ return putCmsg(buf, linux.SCM_CREDENTIALS, align, c)
+}
+
+// AlignUp rounds a length up to the nearest multiple of align. align must be
+// a power of 2.
+func AlignUp(length int, align uint) int {
+ return (length + int(align) - 1) & ^(int(align) - 1)
+}
+
+// AlignDown rounds a length down to the nearest multiple of align. align must
+// be a power of 2.
+func AlignDown(length int, align uint) int {
+ return length & ^(int(align) - 1)
+}
+
+// alignSlice extends a slice's length (up to the capacity) to align it.
+//
+// If rounding len(buf) up to align would exceed cap(buf), buf is returned
+// unchanged.
+func alignSlice(buf []byte, align uint) []byte {
+ aligned := AlignUp(len(buf), align)
+ if aligned > cap(buf) {
+ // Linux allows unaligned data if there isn't room for alignment.
+ // Since there isn't room for alignment, there isn't room for any
+ // additional messages either.
+ return buf
+ }
+ return buf[:aligned]
+}
+
+// Parse parses a raw socket control message into portable objects.
+//
+// buf holds zero or more control messages laid out back to back, each
+// starting with a linux.ControlMessageHeader. Only SOL_SOCKET messages of
+// type SCM_RIGHTS and SCM_CREDENTIALS are accepted; a malformed header, a
+// different level or an unknown type yields EINVAL. FDs from all SCM_RIGHTS
+// messages are accumulated, up to SCM_MAX_FD in total. If several
+// SCM_CREDENTIALS messages are present, the last one wins.
+//
+// When buf carries no credentials, default credentials are derived from t and
+// socketOrEndpoint by makeCreds, which may return none.
+func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (unix.ControlMessages, error) {
+ var (
+ fds linux.ControlMessageRights
+
+ haveCreds bool
+ creds linux.ControlMessageCredentials
+ )
+
+ // i is advanced inside the loop body, past each header and its payload.
+ for i := 0; i < len(buf); {
+ if i+linux.SizeOfControlMessageHeader > len(buf) {
+ return unix.ControlMessages{}, syserror.EINVAL
+ }
+
+ var h linux.ControlMessageHeader
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h)
+
+ // h.Length covers the header itself, so it can be neither shorter than
+ // the header nor longer than what remains of buf.
+ if h.Length < uint64(linux.SizeOfControlMessageHeader) {
+ return unix.ControlMessages{}, syserror.EINVAL
+ }
+ if h.Length > uint64(len(buf)-i) {
+ return unix.ControlMessages{}, syserror.EINVAL
+ }
+ if h.Level != linux.SOL_SOCKET {
+ return unix.ControlMessages{}, syserror.EINVAL
+ }
+
+ // From here on, i points at the payload and length is the payload size.
+ i += linux.SizeOfControlMessageHeader
+ length := int(h.Length) - linux.SizeOfControlMessageHeader
+
+ // The use of t.Arch().Width() is analogous to Linux's use of
+ // sizeof(long) in CMSG_ALIGN.
+ width := t.Arch().Width()
+
+ switch h.Type {
+ case linux.SCM_RIGHTS:
+ // Trailing bytes that do not form a whole FD are ignored.
+ rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ numRights := rightsSize / linux.SizeOfControlMessageRight
+
+ if len(fds)+numRights > linux.SCM_MAX_FD {
+ return unix.ControlMessages{}, syserror.EINVAL
+ }
+
+ for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight {
+ fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
+ }
+
+ i += AlignUp(length, width)
+
+ case linux.SCM_CREDENTIALS:
+ if length < linux.SizeOfControlMessageCredentials {
+ return unix.ControlMessages{}, syserror.EINVAL
+ }
+
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds)
+ haveCreds = true
+ i += AlignUp(length, width)
+
+ default:
+ // Unknown message type.
+ return unix.ControlMessages{}, syserror.EINVAL
+ }
+ }
+
+ // Credentials are validated before any file references are taken, so a
+ // credentials error leaves nothing to release.
+ var credentials SCMCredentials
+ if haveCreds {
+ var err error
+ if credentials, err = NewSCMCredentials(t, creds); err != nil {
+ return unix.ControlMessages{}, err
+ }
+ } else {
+ credentials = makeCreds(t, socketOrEndpoint)
+ }
+
+ var rights SCMRights
+ if len(fds) > 0 {
+ var err error
+ if rights, err = NewSCMRights(t, fds); err != nil {
+ return unix.ControlMessages{}, err
+ }
+ }
+
+ if credentials == nil && rights == nil {
+ return unix.ControlMessages{}, nil
+ }
+
+ return unix.ControlMessages{Credentials: credentials, Rights: rights}, nil
+}
+
+// makeCreds builds default credentials for a message sent by t.
+//
+// It returns nil unless t and socketOrEndpoint are both non-nil and
+// socketOrEndpoint is a unix.Credentialer for which Passcred or
+// ConnectedPasscred reports true. Otherwise the credentials carry t and its
+// effective kernel UID and GID.
+func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials {
+ if t == nil || socketOrEndpoint == nil {
+ return nil
+ }
+ if cr, ok := socketOrEndpoint.(unix.Credentialer); ok && (cr.Passcred() || cr.ConnectedPasscred()) {
+ tcred := t.Credentials()
+ return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID}
+ }
+ return nil
+}
+
+// New creates default control messages if needed.
+//
+// The Credentials field is filled by makeCreds and may therefore be nil;
+// rights is stored unchanged.
+func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) unix.ControlMessages {
+ return unix.ControlMessages{
+ Credentials: makeCreds(t, socketOrEndpoint),
+ Rights: rights,
+ }
+}
diff --git a/pkg/sentry/socket/epsocket/BUILD b/pkg/sentry/socket/epsocket/BUILD
new file mode 100644
index 000000000..0e463a92a
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/BUILD
@@ -0,0 +1,61 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+go_stateify(
+ name = "epsocket_state",
+ srcs = [
+ "epsocket.go",
+ "save_restore.go",
+ "stack.go",
+ ],
+ out = "epsocket_state.go",
+ package = "epsocket",
+)
+
+go_library(
+ name = "epsocket",
+ srcs = [
+ "device.go",
+ "epsocket.go",
+ "epsocket_state.go",
+ "provider.go",
+ "save_restore.go",
+ "stack.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/epsocket",
+ visibility = [
+ "//pkg/sentry:internal",
+ ],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/log",
+ "//pkg/refs",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/kdefs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/safemem",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/usermem",
+ "//pkg/state",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/network/ipv4",
+ "//pkg/tcpip/network/ipv6",
+ "//pkg/tcpip/stack",
+ "//pkg/tcpip/transport/tcp",
+ "//pkg/tcpip/transport/udp",
+ "//pkg/tcpip/transport/unix",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/socket/epsocket/device.go b/pkg/sentry/socket/epsocket/device.go
new file mode 100644
index 000000000..17f2c9559
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// epsocketDevice is the endpoint socket virtual device. It is an anonymous
+// device created once at package initialization.
+var epsocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/epsocket/epsocket.go b/pkg/sentry/socket/epsocket/epsocket.go
new file mode 100644
index 000000000..3fc3ea58f
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/epsocket.go
@@ -0,0 +1,1230 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package epsocket provides an implementation of the socket.Socket interface
+// that is backed by a tcpip.Endpoint.
+//
+// It does not depend on any particular endpoint implementation, and thus can
+// be used to expose certain endpoints to the sentry while leaving others out,
+// for example, TCP endpoints and Unix-domain endpoints.
+//
+// Lock ordering: netstack => mm: ioSequencePayload copies user memory inside
+// tcpip.Endpoint.Write(). Netstack is allowed to (and does) hold locks during
+// this operation.
+package epsocket
+
+import (
+ "bytes"
+ "math"
+ "strings"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// sizeOfInt32 is the size in bytes of an int32. It is the minimum output
+// buffer length required for integer-valued socket options in GetSockOpt.
+const sizeOfInt32 int = 4
+
+// ntohs converts a 16-bit number from network byte order to host byte order. It
+// assumes that the host is little endian, so the conversion is a swap of the
+// two bytes.
+func ntohs(v uint16) uint16 {
+ return v<<8 | v>>8
+}
+
+// htons converts a 16-bit number from host byte order to network byte order. It
+// assumes that the host is little endian. The byte swap is its own inverse,
+// so this simply delegates to ntohs.
+func htons(v uint16) uint16 {
+ return ntohs(v)
+}
+
+// commonEndpoint represents the intersection of a tcpip.Endpoint and a
+// unix.Endpoint. It is the endpoint type accepted by the package-level
+// GetSockOpt function.
+type commonEndpoint interface {
+ // GetLocalAddress implements tcpip.Endpoint.GetLocalAddress and
+ // unix.Endpoint.GetLocalAddress.
+ GetLocalAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // GetRemoteAddress implements tcpip.Endpoint.GetRemoteAddress and
+ // unix.Endpoint.GetRemoteAddress.
+ GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
+
+ // Readiness implements tcpip.Endpoint.Readiness and
+ // unix.Endpoint.Readiness.
+ Readiness(mask waiter.EventMask) waiter.EventMask
+
+ // SetSockOpt implements tcpip.Endpoint.SetSockOpt and
+ // unix.Endpoint.SetSockOpt.
+ SetSockOpt(interface{}) *tcpip.Error
+
+ // GetSockOpt implements tcpip.Endpoint.GetSockOpt and
+ // unix.Endpoint.GetSockOpt.
+ GetSockOpt(interface{}) *tcpip.Error
+}
+
+// SocketOperations encapsulates all the state needed to represent a network stack
+// endpoint in the kernel context.
+type SocketOperations struct {
+ socket.ReceiveTimeout
+ fsutil.PipeSeek `state:"nosave"`
+ fsutil.NotDirReaddir `state:"nosave"`
+ fsutil.NoFsync `state:"nosave"`
+ fsutil.NoopFlush `state:"nosave"`
+ fsutil.NoMMap `state:"nosave"`
+ *waiter.Queue
+
+ // family is the address family the socket was created with; it is used
+ // to validate addresses passed to Connect and Bind.
+ family int
+ stack inet.Stack
+ // Endpoint is the underlying endpoint that all socket operations are
+ // forwarded to.
+ Endpoint tcpip.Endpoint
+ skType unix.SockType
+
+ // readMu protects access to readView and sender.
+ readMu sync.Mutex `state:"nosave"`
+ // readView caches data returned by Endpoint.Read that has not been
+ // consumed yet; sender is the address filled in by that same Read call.
+ readView buffer.View
+ sender tcpip.FullAddress
+}
+
+// New creates a new endpoint socket.
+//
+// The returned file is opened for reading and writing and is backed by a
+// SocketOperations that wraps endpoint and queue. The network stack is taken
+// from t's network context.
+func New(t *kernel.Task, family int, skType unix.SockType, queue *waiter.Queue, endpoint tcpip.Endpoint) *fs.File {
+ dirent := socket.NewDirent(t, epsocketDevice)
+ return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true}, &SocketOperations{
+ Queue: queue,
+ family: family,
+ stack: t.NetworkContext(),
+ Endpoint: endpoint,
+ skType: skType,
+ })
+}
+
+// sockAddrInetSize and sockAddrInet6Size are the encoded sizes of
+// linux.SockAddrInet and linux.SockAddrInet6. GetAddress uses them as the
+// minimum length of an AF_INET or AF_INET6 address.
+var sockAddrInetSize = int(binary.Size(linux.SockAddrInet{}))
+var sockAddrInet6Size = int(binary.Size(linux.SockAddrInet6{}))
+
+// GetAddress reads a sockaddr struct from the given address and converts it
+// to the FullAddress format. It supports AF_UNIX, AF_INET and AF_INET6
+// addresses.
+//
+// sfamily is the family the caller expects. ErrInvalidArgument is returned
+// when addr is too short to hold a family, ErrAddressFamilyNotSupported when
+// the encoded family differs from sfamily or is not supported, and
+// ErrBadAddress when addr is too short for the AF_INET or AF_INET6 structure.
+// An all-zero IPv4 or IPv6 address is returned as the empty address.
+func GetAddress(sfamily int, addr []byte) (tcpip.FullAddress, *syserr.Error) {
+ // Make sure we have at least 2 bytes for the address family.
+ if len(addr) < 2 {
+ return tcpip.FullAddress{}, syserr.ErrInvalidArgument
+ }
+
+ family := usermem.ByteOrder.Uint16(addr)
+ if family != uint16(sfamily) {
+ return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ }
+
+ // Get the rest of the fields based on the address family.
+ switch family {
+ case linux.AF_UNIX:
+ path := addr[2:]
+ // Drop the terminating NUL (if one exists) and everything after it.
+ // Skip the first byte, which is NUL for abstract paths.
+ if len(path) > 1 {
+ // n is relative to path[1:], hence the +1 when truncating path.
+ if n := bytes.IndexByte(path[1:], 0); n >= 0 {
+ path = path[:n+1]
+ }
+ }
+ return tcpip.FullAddress{
+ Addr: tcpip.Address(path),
+ }, nil
+
+ case linux.AF_INET:
+ var a linux.SockAddrInet
+ if len(addr) < sockAddrInetSize {
+ return tcpip.FullAddress{}, syserr.ErrBadAddress
+ }
+ binary.Unmarshal(addr[:sockAddrInetSize], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: tcpip.Address(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ // The all-zero address is represented by the empty address.
+ if out.Addr == "\x00\x00\x00\x00" {
+ out.Addr = ""
+ }
+ return out, nil
+
+ case linux.AF_INET6:
+ var a linux.SockAddrInet6
+ if len(addr) < sockAddrInet6Size {
+ return tcpip.FullAddress{}, syserr.ErrBadAddress
+ }
+ binary.Unmarshal(addr[:sockAddrInet6Size], usermem.ByteOrder, &a)
+
+ out := tcpip.FullAddress{
+ Addr: tcpip.Address(a.Addr[:]),
+ Port: ntohs(a.Port),
+ }
+ // The scope ID is only used as the NIC for link-local addresses.
+ if isLinkLocal(out.Addr) {
+ out.NIC = tcpip.NICID(a.Scope_id)
+ }
+ // The all-zero address is represented by the empty address.
+ if out.Addr == tcpip.Address(strings.Repeat("\x00", 16)) {
+ out.Addr = ""
+ }
+ return out, nil
+
+ default:
+ return tcpip.FullAddress{}, syserr.ErrAddressFamilyNotSupported
+ }
+}
+
+// isPacketBased reports whether the socket type is one of SOCK_DGRAM,
+// SOCK_SEQPACKET or SOCK_RDM.
+func (s *SocketOperations) isPacketBased() bool {
+ return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM
+}
+
+// fetchReadView updates the readView field of the socket if it's currently
+// empty. It assumes that the socket is locked.
+//
+// When the cached view is empty, readView and sender are reset and a new view
+// is read from the endpoint, which also fills in sender. If that read fails,
+// the translated error is returned and readView stays empty.
+func (s *SocketOperations) fetchReadView() *syserr.Error {
+ if len(s.readView) > 0 {
+ return nil
+ }
+
+ s.readView = nil
+ s.sender = tcpip.FullAddress{}
+
+ v, err := s.Endpoint.Read(&s.sender)
+ if err != nil {
+ return syserr.TranslateNetstackError(err)
+ }
+
+ s.readView = v
+
+ return nil
+}
+
+// Release implements fs.FileOperations.Release. It closes the underlying
+// endpoint.
+func (s *SocketOperations) Release() {
+ s.Endpoint.Close()
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// A zero-length destination returns immediately. ErrWouldBlock is passed
+// through together with the byte count; any other error is converted with
+// ToError and reported with a count of 0. The file and offset arguments are
+// not used.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ n, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false)
+ if err == syserr.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ if err != nil {
+ return 0, err.ToError()
+ }
+ return int64(n), nil
+}
+
+// ioSequencePayload implements tcpip.Payload. It copies user memory bytes on demand
+// based on the requested size.
+//
+// ctx is the context passed to CopyIn; src is the user memory sequence the
+// bytes are copied from.
+type ioSequencePayload struct {
+ ctx context.Context
+ src usermem.IOSequence
+}
+
+// Get implements tcpip.Payload.
+//
+// size is clamped to the number of bytes available in the source. Any CopyIn
+// failure is reported as tcpip.ErrBadAddress.
+func (i *ioSequencePayload) Get(size int) ([]byte, *tcpip.Error) {
+ if size > i.Size() {
+ size = i.Size()
+ }
+ v := buffer.NewView(size)
+ if _, err := i.src.CopyIn(i.ctx, v); err != nil {
+ return nil, tcpip.ErrBadAddress
+ }
+ return v, nil
+}
+
+// Size implements tcpip.Payload. It returns the number of bytes in the source
+// sequence.
+func (i *ioSequencePayload) Size() int {
+ return int(i.src.NumBytes())
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// src is wrapped in an ioSequencePayload so that the endpoint copies the user
+// memory on demand. tcpip.ErrWouldBlock is reported as
+// syserror.ErrWouldBlock; every other result is translated through syserr.
+// The file and offset arguments are not used.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ f := &ioSequencePayload{ctx: ctx, src: src}
+ n, err := s.Endpoint.Write(f, tcpip.WriteOptions{})
+ if err == tcpip.ErrWouldBlock {
+ return int64(n), syserror.ErrWouldBlock
+ }
+ return int64(n), syserr.TranslateNetstackError(err).ToError()
+}
+
+// Readiness returns a mask of ready events for socket s.
+//
+// The endpoint's own readiness is extended with EventIn when data is still
+// cached in readView.
+func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ r := s.Endpoint.Readiness(mask)
+
+ // Check our cached value iff the caller asked for readability and the
+ // endpoint itself is currently not readable.
+ if (mask & ^r & waiter.EventIn) != 0 {
+ s.readMu.Lock()
+ if len(s.readView) > 0 {
+ r |= waiter.EventIn
+ }
+ s.readMu.Unlock()
+ }
+
+ return r
+}
+
+// Connect implements the linux syscall connect(2) for sockets backed by
+// tcpip.Endpoint.
+//
+// sockaddr is decoded with GetAddress using the socket's family. In the
+// non-blocking case the endpoint's result is returned directly. In the
+// blocking case the task waits for EventOut if the connection is pending and
+// then calls Connect on the endpoint again to obtain the final result.
+func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ addr, err := GetAddress(s.family, sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // Always return right away in the non-blocking case.
+ if !blocking {
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+ }
+
+ // Register for notification when the endpoint becomes writable, then
+ // initiate the connection.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ // Any result other than "started" or "already connecting" is final,
+ // including success.
+ if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting {
+ return syserr.TranslateNetstackError(err)
+ }
+
+ // It's pending, so we have to wait for a notification, and fetch the
+ // result once the wait completes.
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+
+ // Call Connect() again after blocking to find connect's result.
+ return syserr.TranslateNetstackError(s.Endpoint.Connect(addr))
+}
+
+// Bind implements the linux syscall bind(2) for sockets backed by
+// tcpip.Endpoint.
+//
+// sockaddr is decoded with GetAddress using the socket's family; a decoding
+// error is returned without contacting the endpoint.
+func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ addr, err := GetAddress(s.family, sockaddr)
+ if err != nil {
+ return err
+ }
+
+ // Issue the bind request to the endpoint.
+ return syserr.TranslateNetstackError(s.Endpoint.Bind(addr, nil))
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// tcpip.Endpoint. The backlog is passed to the endpoint unchanged.
+func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog))
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accepted, it will block until one becomes ready.
+//
+// It returns the accepted endpoint and its waiter queue, or the error that
+// ended the wait (an endpoint error other than ErrWouldBlock, or an error
+// from blocking the task).
+func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ // Try to accept the connection again; if it fails, then wait until we
+ // get a notification.
+ for {
+ if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock {
+ return ep, wq, syserr.TranslateNetstackError(err)
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// tcpip.Endpoint.
+//
+// If no connection is ready and blocking is set, it waits via blockingAccept.
+// The accepted endpoint is wrapped in a new socket file. SOCK_NONBLOCK and
+// SOCK_CLOEXEC in flags are honored, and the peer address is looked up only
+// when peerRequested is set. It returns the new FD, the peer address and the
+// length of that address.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, wq, err := s.Endpoint.Accept()
+ if err != nil {
+ if err != tcpip.ErrWouldBlock || !blocking {
+ return 0, nil, 0, syserr.TranslateNetstackError(err)
+ }
+
+ // This err shadows the *tcpip.Error above; it has the *syserr.Error
+ // type returned by blockingAccept.
+ var err *syserr.Error
+ ep, wq, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ ns := New(t, s.family, s.skType, wq, ep)
+ // Drop this function's reference to ns when returning.
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ // This flags shadows the int parameter and holds the file flags of ns.
+ flags := ns.Flags()
+ flags.NonBlocking = true
+ ns.SetFlags(flags.Settable())
+ }
+
+ var addr interface{}
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer and write it to peer slice.
+ var err *syserr.Error
+ addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ }
+ fd, e := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+
+ return fd, addr, addrLen, syserr.FromError(e)
+}
+
+// ConvertShutdown converts Linux shutdown flags into tcpip shutdown flags.
+//
+// SHUT_RD, SHUT_WR and SHUT_RDWR map to the read flag, the write flag and
+// both flags respectively. Any other value yields ErrInvalidArgument.
+func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) {
+ var f tcpip.ShutdownFlags
+ switch how {
+ case linux.SHUT_RD:
+ f = tcpip.ShutdownRead
+ case linux.SHUT_WR:
+ f = tcpip.ShutdownWrite
+ case linux.SHUT_RDWR:
+ f = tcpip.ShutdownRead | tcpip.ShutdownWrite
+ default:
+ return 0, syserr.ErrInvalidArgument
+ }
+ return f, nil
+}
+
+ // Shutdown implements the linux syscall shutdown(2) for sockets backed by
+ // tcpip.Endpoint. The Linux `how` value is validated and converted by
+ // ConvertShutdown before being handed to the endpoint.
+ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ 	flags, convErr := ConvertShutdown(how)
+ 	if convErr != nil {
+ 		return convErr
+ 	}
+
+ 	// Forward the request to the endpoint and translate its result.
+ 	netErr := s.Endpoint.Shutdown(flags)
+ 	return syserr.TranslateNetstackError(netErr)
+ }
+
+ // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+ // tcpip.Endpoint.
+ //
+ // It delegates to the package-level GetSockOpt, passing this socket's
+ // endpoint, address family and socket type.
+ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+ }
+
+ // GetSockOpt can be used to implement the linux syscall getsockopt(2) for
+ // sockets backed by a commonEndpoint.
+ //
+ // The option is selected by level (SOL_SOCKET, SOL_TCP or SOL_IPV6) and then
+ // by name. For fixed-size options outLen must be at least the size of the
+ // returned value, otherwise syserr.ErrInvalidArgument is returned. TCP_INFO
+ // is instead truncated to outLen. Any level/name pair not handled below
+ // yields syserr.ErrProtocolNotAvailable.
+ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType unix.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+ switch level {
+ case syscall.SOL_SOCKET:
+ switch name {
+ case linux.SO_TYPE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ return int32(skType), nil
+
+ case linux.SO_ERROR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Get the last error and convert it.
+ // The endpoint's return value is itself the pending error; a nil
+ // return is reported to the caller as 0.
+ err := ep.GetSockOpt(tcpip.ErrorOption{})
+ if err == nil {
+ return int32(0), nil
+ }
+ return int32(syserr.ToLinux(syserr.TranslateNetstackError(err)).Number()), nil
+
+ case linux.SO_PEERCRED:
+ // Only AF_UNIX sockets support this option.
+ if family != linux.AF_UNIX || outLen < syscall.SizeofUcred {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // NOTE(review): the credentials reported are those of the calling
+ // task t, not ones recorded from the peer - confirm this is the
+ // intended approximation.
+ tcred := t.Credentials()
+ return syscall.Ucred{
+ Pid: int32(t.ThreadGroup().ID()),
+ Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()),
+ Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()),
+ }, nil
+
+ case linux.SO_PASSCRED:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.PasscredOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_SNDBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var size tcpip.SendBufferSizeOption
+ if err := ep.GetSockOpt(&size); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // Clamp so the conversion to int32 below cannot overflow.
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_RCVBUF:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var size tcpip.ReceiveBufferSizeOption
+ if err := ep.GetSockOpt(&size); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // Clamp so the conversion to int32 below cannot overflow.
+ if size > math.MaxInt32 {
+ size = math.MaxInt32
+ }
+
+ return int32(size), nil
+
+ case linux.SO_REUSEADDR:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.ReuseAddressOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case linux.SO_KEEPALIVE:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // The endpoint is not consulted; 0 is always reported.
+ return int32(0), nil
+
+ case linux.SO_LINGER:
+ if outLen < syscall.SizeofLinger {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // The endpoint is not consulted; a zero-valued Linger is always
+ // reported.
+ return syscall.Linger{}, nil
+
+ case linux.SO_RCVTIMEO:
+ if outLen < linux.SizeOfTimeval {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // The timeout is kept on the socket object, not on the endpoint.
+ return linux.NsecToTimeval(s.RecvTimeout()), nil
+ }
+
+ case syscall.SOL_TCP:
+ switch name {
+ case syscall.TCP_NODELAY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.NoDelayOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+
+ case syscall.TCP_INFO:
+ // The option is queried only so that endpoint errors are reported;
+ // v's contents are not used yet (see the TODO below).
+ var v tcpip.TCPInfoOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ // TODO: Translate fields once they are added to
+ // tcpip.TCPInfoOption.
+ info := linux.TCPInfo{}
+
+ // Linux truncates the output binary to outLen.
+ ib := binary.Marshal(nil, usermem.ByteOrder, &info)
+ if len(ib) > outLen {
+ ib = ib[:outLen]
+ }
+
+ return ib, nil
+ }
+
+ case syscall.SOL_IPV6:
+ switch name {
+ case syscall.IPV6_V6ONLY:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ var v tcpip.V6OnlyOption
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ return int32(v), nil
+ }
+ }
+
+ // Reached for every level/name combination not handled above.
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ // SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+ // tcpip.Endpoint.
+ //
+ // It delegates to the package-level SetSockOpt, passing this socket's
+ // endpoint.
+ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
+ }
+
+ // SetSockOpt can be used to implement the linux syscall setsockopt(2) for
+ // sockets backed by a commonEndpoint.
+ //
+ // optVal holds the raw option bytes supplied by the caller. Every option
+ // handled here first checks that optVal is long enough for the value it
+ // decodes and returns syserr.ErrInvalidArgument otherwise. Integer options
+ // are decoded as a 32-bit value in usermem.ByteOrder and converted to the
+ // matching tcpip option type. Options not handled explicitly fall through
+ // to the generic hand-off at the end of the function.
+ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+ switch level {
+ case syscall.SOL_SOCKET:
+ switch name {
+ case linux.SO_SNDBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.SendBufferSizeOption(v)))
+
+ case linux.SO_RCVBUF:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReceiveBufferSizeOption(v)))
+
+ case linux.SO_REUSEADDR:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+
+ case linux.SO_PASSCRED:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+
+ case linux.SO_RCVTIMEO:
+ if len(optVal) < linux.SizeOfTimeval {
+ return syserr.ErrInvalidArgument
+ }
+
+ // The timeout is stored on the socket object rather than on the
+ // endpoint. Only the leading SizeOfTimeval bytes are decoded.
+ var v linux.Timeval
+ binary.Unmarshal(optVal[:linux.SizeOfTimeval], usermem.ByteOrder, &v)
+ s.SetRecvTimeout(v.ToNsecCapped())
+ return nil
+ }
+
+ case syscall.SOL_TCP:
+ switch name {
+ case syscall.TCP_NODELAY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.NoDelayOption(v)))
+ }
+ case syscall.SOL_IPV6:
+ switch name {
+ case syscall.IPV6_V6ONLY:
+ if len(optVal) < sizeOfInt32 {
+ return syserr.ErrInvalidArgument
+ }
+
+ v := usermem.ByteOrder.Uint32(optVal)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v)))
+ }
+ }
+
+ // FIXME: Disallow IP-level multicast group options by
+ // default. These will need to be supported by appropriately plumbing
+ // the level through to the network stack (if at all). However, we
+ // still allow setting TTL, and multicast-enable/disable type options.
+ // NOTE(review): level 0 is presumably the IP level (SOL_IP) - confirm.
+ if level == 0 {
+ const (
+ _IP_ADD_MEMBERSHIP = 35
+ _MCAST_JOIN_GROUP = 42
+ )
+ if name == _IP_ADD_MEMBERSHIP || name == _MCAST_JOIN_GROUP {
+ return syserr.ErrInvalidArgument
+ }
+ }
+
+ // Default to the old behavior; hand off to network stack.
+ // An empty struct value is passed as the option; how the endpoint treats
+ // it is not visible from this function.
+ return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+ }
+
+// isLinkLocal determines if the given IPv6 address is link-local. This is the
+// case when it has the fe80::/10 prefix. This check is used to determine when
+// the NICID is relevant for a given IPv6 address.
+func isLinkLocal(addr tcpip.Address) bool {
+ return len(addr) >= 2 && addr[0] == 0xfe && addr[1]&0xc0 == 0x80
+}
+
+// ConvertAddress converts the given address to a native format.
+func ConvertAddress(family int, addr tcpip.FullAddress) (interface{}, uint32) {
+ switch family {
+ case linux.AF_UNIX:
+ var out linux.SockAddrUnix
+ out.Family = linux.AF_UNIX
+ for i := 0; i < len([]byte(addr.Addr)); i++ {
+ out.Path[i] = int8(addr.Addr[i])
+ }
+ // Linux just returns the header for empty addresses.
+ if len(addr.Addr) == 0 {
+ return out, 2
+ }
+ // Linux returns the used length of the address struct (including the
+ // null terminator) for filesystem paths. The Family field is 2 bytes.
+ // It is sometimes allowed to exclude the null terminator if the
+ // address length is the max. Abstract paths always return the full
+ // length.
+ if out.Path[0] == 0 || len([]byte(addr.Addr)) == len(out.Path) {
+ return out, uint32(binary.Size(out))
+ }
+ return out, uint32(3 + len(addr.Addr))
+ case linux.AF_INET:
+ var out linux.SockAddrInet
+ copy(out.Addr[:], addr.Addr)
+ out.Family = linux.AF_INET
+ out.Port = htons(addr.Port)
+ return out, uint32(binary.Size(out))
+ case linux.AF_INET6:
+ var out linux.SockAddrInet6
+ if len(addr.Addr) == 4 {
+ // Copy address is v4-mapped format.
+ copy(out.Addr[12:], addr.Addr)
+ out.Addr[10] = 0xff
+ out.Addr[11] = 0xff
+ } else {
+ copy(out.Addr[:], addr.Addr)
+ }
+ out.Family = linux.AF_INET6
+ out.Port = htons(addr.Port)
+ if isLinkLocal(addr.Addr) {
+ out.Scope_id = uint32(addr.NIC)
+ }
+ return out, uint32(binary.Size(out))
+ default:
+ return nil, 0
+ }
+}
+
+ // GetSockName implements the linux syscall getsockname(2) for sockets backed by
+ // tcpip.Endpoint. The endpoint's local address is converted to the native
+ // format for this socket's family.
+ func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ 	local, netErr := s.Endpoint.GetLocalAddress()
+ 	if netErr != nil {
+ 		return nil, 0, syserr.TranslateNetstackError(netErr)
+ 	}
+
+ 	sockAddr, sockAddrLen := ConvertAddress(s.family, local)
+ 	return sockAddr, sockAddrLen, nil
+ }
+
+ // GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+ // tcpip.Endpoint. The endpoint's remote address is converted to the native
+ // format for this socket's family.
+ func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ 	remote, netErr := s.Endpoint.GetRemoteAddress()
+ 	if netErr != nil {
+ 		return nil, 0, syserr.TranslateNetstackError(netErr)
+ 	}
+
+ 	sockAddr, sockAddrLen := ConvertAddress(s.family, remote)
+ 	return sockAddr, sockAddrLen, nil
+ }
+
+ // coalescingRead is the fast path for non-blocking, non-peek, stream-based
+ // case. It coalesces as many packets as possible before returning to the
+ // caller.
+ //
+ // When discard is true nothing is copied to dst; the bytes are only counted
+ // and dropped from the read view. The return value is the number of bytes
+ // consumed. An error is reported only when no bytes were consumed at all.
+ // The only caller visible in this file (nonBlockingRead) holds s.readMu
+ // around the call.
+ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) {
+ var err *syserr.Error
+ var copied int
+
+ // Copy as many views as possible into the user-provided buffer.
+ for dst.NumBytes() != 0 {
+ err = s.fetchReadView()
+ if err != nil {
+ break
+ }
+
+ var n int
+ var e error
+ if discard {
+ // Consume up to the remaining size of dst without copying.
+ n = len(s.readView)
+ if int64(n) > dst.NumBytes() {
+ n = int(dst.NumBytes())
+ }
+ } else {
+ n, e = dst.CopyOut(ctx, s.readView)
+ }
+ // Account for the bytes consumed before looking at e, so that a
+ // partial copy is still delivered and removed from the view.
+ copied += n
+ s.readView.TrimFront(n)
+ dst = dst.DropFirst(n)
+ if e != nil {
+ err = syserr.FromError(e)
+ break
+ }
+ }
+
+ // If we managed to copy something, we must deliver it.
+ if copied > 0 {
+ return copied, nil
+ }
+
+ return 0, err
+ }
+
+ // nonBlockingRead issues a non-blocking read.
+ //
+ // peek leaves the data in place, trunc selects MSG_TRUNC semantics and
+ // senderRequested asks for the sender address on packet-based sockets. The
+ // results are the byte count, the sender address and its length (set only
+ // for packet-based sockets when senderRequested is true), and an error.
+ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, interface{}, uint32, *syserr.Error) {
+ isPacket := s.isPacketBased()
+
+ // Fast path for regular reads from stream (e.g., TCP) endpoints. Note
+ // that senderRequested is ignored for stream sockets.
+ if !peek && !isPacket {
+ // TCP sockets discard the data if MSG_TRUNC is set.
+ //
+ // This behavior is documented in man 7 tcp:
+ // Since version 2.4, Linux supports the use of MSG_TRUNC in the flags
+ // argument of recv(2) (and recvmsg(2)). This flag causes the received
+ // bytes of data to be discarded, rather than passed back in a
+ // caller-supplied buffer.
+ s.readMu.Lock()
+ n, err := s.coalescingRead(ctx, dst, trunc)
+ s.readMu.Unlock()
+ return n, nil, 0, err
+ }
+
+ // Everything below reads or modifies s.readView, so the lock is held
+ // until return.
+ s.readMu.Lock()
+ defer s.readMu.Unlock()
+
+ if err := s.fetchReadView(); err != nil {
+ return 0, nil, 0, err
+ }
+
+ if !isPacket && peek && trunc {
+ // MSG_TRUNC with MSG_PEEK on a TCP socket returns the
+ // amount that could be read.
+ var rql tcpip.ReceiveQueueSizeOption
+ if err := s.Endpoint.GetSockOpt(&rql); err != nil {
+ return 0, nil, 0, syserr.TranslateNetstackError(err)
+ }
+ // Bytes already held in the read view plus what the endpoint still
+ // has queued, capped at the size of dst.
+ available := len(s.readView) + int(rql)
+ bufLen := int(dst.NumBytes())
+ if available < bufLen {
+ return available, nil, 0, nil
+ }
+ return bufLen, nil, 0, nil
+ }
+
+ n, err := dst.CopyOut(ctx, s.readView)
+ var addr interface{}
+ var addrLen uint32
+ if isPacket && senderRequested {
+ addr, addrLen = ConvertAddress(s.family, s.sender)
+ }
+
+ if peek {
+ // With MSG_TRUNC, report the full size of the view when it did not
+ // fit into dst.
+ if l := len(s.readView); trunc && l > n {
+ // isPacket must be true.
+ return l, addr, addrLen, syserr.FromError(err)
+ }
+
+ if isPacket || err != nil {
+ return int(n), addr, addrLen, syserr.FromError(err)
+ }
+
+ // We need to peek beyond the first message.
+ // Fill the rest of dst from the endpoint via Endpoint.Peek.
+ dst = dst.DropFirst(n)
+ num, err := dst.CopyOutFrom(ctx, safemem.FromVecReaderFunc{func(dsts [][]byte) (int64, error) {
+ n, err := s.Endpoint.Peek(dsts)
+ if err != nil {
+ return int64(n), syserr.TranslateNetstackError(err).ToError()
+ }
+ return int64(n), nil
+ }})
+ n += int(num)
+ if err == syserror.ErrWouldBlock && n > 0 {
+ // We got some data, so no need to return an error.
+ err = nil
+ }
+ return int(n), nil, 0, syserr.FromError(err)
+ }
+
+ // Not peeking: consume what was delivered. A packet is dropped whole
+ // even if dst was too small for it; a stream view only loses the bytes
+ // that were copied.
+ var msgLen int
+ if isPacket {
+ msgLen = len(s.readView)
+ s.readView = nil
+ } else {
+ msgLen = int(n)
+ s.readView.TrimFront(int(n))
+ }
+
+ // With MSG_TRUNC the length of the message is reported rather than the
+ // number of bytes copied.
+ if trunc {
+ return msgLen, addr, addrLen, syserr.FromError(err)
+ }
+
+ return int(n), addr, addrLen, syserr.FromError(err)
+ }
+
+ // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+ // tcpip.Endpoint.
+ //
+ // A non-blocking read is attempted first. If it would block and MSG_DONTWAIT
+ // is not set in flags, the call waits for the socket to become readable,
+ // bounded by deadline when haveDeadline is true, and retries. Expiry of the
+ // deadline is reported as syserr.ErrTryAgain. controlMessages is always
+ // returned empty; controlDataLen is not used by this function.
+ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages unix.ControlMessages, err *syserr.Error) {
+ trunc := flags&linux.MSG_TRUNC != 0
+
+ peek := flags&linux.MSG_PEEK != 0
+ if senderRequested && !s.isPacketBased() {
+ // Stream sockets ignore the sender address.
+ senderRequested = false
+ }
+ n, senderAddr, senderAddrLen, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ // Naked return: the named results assigned above are returned.
+ return
+ }
+
+ // We'll have to block. Register for notifications and keep trying to
+ // receive data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+
+ for {
+ // Retry after registering, so data that arrived between the first
+ // attempt and the registration is not missed.
+ n, senderAddr, senderAddrLen, err = s.nonBlockingRead(t, dst, peek, trunc, senderRequested)
+ if err != syserr.ErrWouldBlock {
+ return
+ }
+
+ if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ // A timeout is reported as "try again" rather than as ETIMEDOUT.
+ if err == syserror.ETIMEDOUT {
+ return 0, nil, 0, unix.ControlMessages{}, syserr.ErrTryAgain
+ }
+ return 0, nil, 0, unix.ControlMessages{}, syserr.FromError(err)
+ }
+ }
+ }
+
+ // SendMsg implements the linux syscall sendmsg(2) for sockets backed by
+ // tcpip.Endpoint.
+ //
+ // Control messages are not supported and cause syserr.ErrInvalidArgument.
+ // If to is non-empty it is parsed as the destination address. The payload
+ // is copied from src in full before the first write. When the endpoint
+ // would block and MSG_DONTWAIT is not set, the call waits for the socket to
+ // become writable and retries with the unsent remainder. The returned count
+ // is the total number of bytes the endpoint accepted.
+ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
+ // Reject control messages.
+ if !controlMessages.Empty() {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ // addr stays nil when no destination was supplied.
+ var addr *tcpip.FullAddress
+ if len(to) > 0 {
+ addrBuf, err := GetAddress(s.family, to)
+ if err != nil {
+ return 0, err
+ }
+
+ addr = &addrBuf
+ }
+
+ // NOTE(review): the buffer is sized directly from the caller-supplied
+ // length; confirm an upper bound is enforced before this point.
+ v := buffer.NewView(int(src.NumBytes()))
+
+ // Copy all the data into the buffer.
+ if _, err := src.CopyIn(t, v); err != nil {
+ return 0, syserr.FromError(err)
+ }
+
+ opts := tcpip.WriteOptions{
+ To: addr,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }
+
+ n, err := s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ if err != tcpip.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ return int(n), syserr.TranslateNetstackError(err)
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ // Drop whatever the first attempt accepted so that the loop resumes with
+ // the remainder only.
+ v.TrimFront(int(n))
+ total := n
+ for {
+ n, err = s.Endpoint.Write(tcpip.SlicePayload(v), opts)
+ v.TrimFront(int(n))
+ total += n
+ if err != tcpip.ErrWouldBlock {
+ return int(total), syserr.TranslateNetstackError(err)
+ }
+
+ // The wait has no deadline. If it fails, the bytes accepted so far
+ // are still reported alongside the error.
+ if err := t.Block(ch); err != nil {
+ return int(total), syserr.FromError(err)
+ }
+ }
+ }
+
+ // interfaceIoctl implements interface requests.
+ //
+ // The interface is looked up by the name stored in ifr; if no interface of
+ // that name exists, syserr.ErrNoDevice is returned. On success ifr.Data is
+ // filled in according to the request arg. Requests that are recognised but
+ // not implemented yet leave ifr.Data untouched and succeed. Any other
+ // request yields syserr.ErrInvalidArgument.
+ func (s *SocketOperations) interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFReq) *syserr.Error {
+ 	var (
+ 		iface inet.Interface
+ 		index int32
+ 		found bool
+ 	)
+
+ 	// Find the relevant device. The loop assigns to the outer index and
+ 	// iface so that they keep the matching entry after the break.
+ 	for index, iface = range s.stack.Interfaces() {
+ 		if iface.Name == ifr.Name() {
+ 			found = true
+ 			break
+ 		}
+ 	}
+ 	if !found {
+ 		return syserr.ErrNoDevice
+ 	}
+
+ 	switch arg {
+ 	case syscall.SIOCGIFINDEX:
+ 		// Copy out the index to the data.
+ 		usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index))
+
+ 	case syscall.SIOCGIFHWADDR:
+ 		// Copy the hardware address out. Per netdevice(7), ifr_hwaddr is
+ 		// a struct sockaddr whose sa_family (the first two bytes) holds
+ 		// the ARP hardware type and whose sa_data holds the link address
+ 		// starting from byte 0.
+ 		//
+ 		// The hardware type must not be replaced by the number of address
+ 		// bytes copied: that is wrong whenever the link address is not
+ 		// exactly 6 bytes long, e.g. for an interface without one.
+ 		ifr.Data[0] = 6 // IEEE802.2 arp type.
+ 		ifr.Data[1] = 0
+ 		n := copy(ifr.Data[2:], iface.Addr)
+ 		for i := 2 + n; i < len(ifr.Data); i++ {
+ 			ifr.Data[i] = 0 // Clear padding.
+ 		}
+
+ 	case syscall.SIOCGIFFLAGS:
+ 		// TODO: Implement. For now, return only that the
+ 		// device is up so that ifconfig prints it.
+ 		usermem.ByteOrder.PutUint16(ifr.Data[:2], linux.IFF_UP)
+
+ 	case syscall.SIOCGIFADDR:
+ 		// Copy the IPv4 address out. Only the first AF_INET address of
+ 		// the interface is reported.
+ 		for _, addr := range s.stack.InterfaceAddrs()[index] {
+ 			// This ioctl is only compatible with AF_INET addresses.
+ 			if addr.Family != linux.AF_INET {
+ 				continue
+ 			}
+ 			copy(ifr.Data[4:8], addr.Addr)
+ 			break
+ 		}
+
+ 	case syscall.SIOCGIFMETRIC:
+ 		// Gets the metric of the device. As per netdevice(7), this
+ 		// always just sets ifr_metric to 0.
+ 		usermem.ByteOrder.PutUint32(ifr.Data[:4], 0)
+
+ 	case syscall.SIOCGIFMTU:
+ 		// Gets the MTU of the device.
+ 		// TODO: Implement.
+
+ 	case syscall.SIOCGIFMAP:
+ 		// Gets the hardware parameters of the device.
+ 		// TODO: Implement.
+
+ 	case syscall.SIOCGIFTXQLEN:
+ 		// Gets the transmit queue length of the device.
+ 		// TODO: Implement.
+
+ 	case syscall.SIOCGIFDSTADDR:
+ 		// Gets the destination address of a point-to-point device.
+ 		// TODO: Implement.
+
+ 	case syscall.SIOCGIFBRDADDR:
+ 		// Gets the broadcast address of a device.
+ 		// TODO: Implement.
+
+ 	case syscall.SIOCGIFNETMASK:
+ 		// Gets the network mask of a device.
+ 		// TODO: Implement.
+
+ 	default:
+ 		// Not a valid call.
+ 		return syserr.ErrInvalidArgument
+ 	}
+
+ 	return nil
+ }
+
+ // Ioctl implements fs.FileOperations.Ioctl.
+ //
+ // The request number is taken from args[1] and the user pointer from
+ // args[2]. Per-interface requests (SIOCGIF*) copy a struct ifreq in, let
+ // interfaceIoctl fill it and copy it back out. SIOCGIFCONF does the same
+ // with a struct ifconf via ifconfIoctl. Everything else is forwarded to the
+ // package-level Ioctl.
+ func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch arg := int(args[1].Int()); arg {
+ case syscall.SIOCGIFFLAGS,
+ syscall.SIOCGIFADDR,
+ syscall.SIOCGIFBRDADDR,
+ syscall.SIOCGIFDSTADDR,
+ syscall.SIOCGIFHWADDR,
+ syscall.SIOCGIFINDEX,
+ syscall.SIOCGIFMAP,
+ syscall.SIOCGIFMETRIC,
+ syscall.SIOCGIFMTU,
+ syscall.SIOCGIFNETMASK,
+ syscall.SIOCGIFTXQLEN:
+
+ // Read the request (it carries the interface name), update it in
+ // place and write the whole structure back.
+ var ifr linux.IFReq
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+ if err := s.interfaceIoctl(ctx, io, arg, &ifr); err != nil {
+ return 0, err.ToError()
+ }
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+ case syscall.SIOCGIFCONF:
+ // Return a list of interface addresses or the buffer size
+ // necessary to hold the list.
+ var ifc linux.IFConf
+ if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ }); err != nil {
+ return 0, err
+ }
+
+ if err := s.ifconfIoctl(ctx, io, &ifc); err != nil {
+ return 0, err
+ }
+
+ // Write the header back so the caller sees the updated Len.
+ _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+
+ return 0, err
+ }
+
+ // Not an interface request: try the endpoint-level ioctls.
+ return Ioctl(ctx, s.Endpoint, io, args)
+ }
+
+ // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl.
+ //
+ // If ifc.Ptr is NULL, only ifc.Len is set, to the buffer size reported as
+ // necessary. Otherwise one struct ifreq per AF_INET interface address is
+ // written to user memory starting at ifc.Ptr, never exceeding the ifc.Len
+ // bytes supplied by the caller, and ifc.Len is updated to the number of
+ // bytes written. An error is returned only if copying to user memory fails.
+ func (s *SocketOperations) ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error {
+ 	// If Ptr is NULL, return the necessary buffer size via Len.
+ 	// Otherwise, write up to Len bytes starting at Ptr containing ifreq
+ 	// structs.
+ 	if ifc.Ptr == 0 {
+ 		ifc.Len = int32(len(s.stack.Interfaces())) * int32(linux.SizeOfIFReq)
+ 		return nil
+ 	}
+
+ 	limit := ifc.Len
+ 	ifc.Len = 0
+
+ 	// Take the interface table once instead of rebuilding it for every
+ 	// interface visited below.
+ 	ifaces := s.stack.Interfaces()
+ 	for key, ifaceAddrs := range s.stack.InterfaceAddrs() {
+ 		iface := ifaces[key]
+ 		for _, ifaceAddr := range ifaceAddrs {
+ 			// Don't write past the end of the buffer. Once it is full
+ 			// nothing more can be written for any interface, so stop
+ 			// altogether.
+ 			if ifc.Len+int32(linux.SizeOfIFReq) > limit {
+ 				return nil
+ 			}
+ 			if ifaceAddr.Family != linux.AF_INET {
+ 				continue
+ 			}
+
+ 			// Populate ifr.ifr_addr: family, a zero port and the IPv4
+ 			// address. copy is bounded by the 4-byte destination, so an
+ 			// unexpectedly short address cannot cause a panic.
+ 			ifr := linux.IFReq{}
+ 			ifr.SetName(iface.Name)
+ 			usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family))
+ 			usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0)
+ 			copy(ifr.Data[4:8], ifaceAddr.Addr)
+
+ 			// Copy the ifr to userspace.
+ 			dst := uintptr(ifc.Ptr) + uintptr(ifc.Len)
+ 			ifc.Len += int32(linux.SizeOfIFReq)
+ 			if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{
+ 				AddressSpaceActive: true,
+ 			}); err != nil {
+ 				return err
+ 			}
+ 		}
+ 	}
+ 	return nil
+ }
+
+ // Ioctl implements fs.FileOperations.Ioctl for sockets backed by a
+ // commonEndpoint.
+ //
+ // TIOCINQ and TIOCOUTQ report the endpoint's receive and send queue size
+ // respectively, clamped to math.MaxInt32 and written as an int32 to the
+ // user address in args[2]. Every other request fails with ENOTTY.
+ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ 	var queued int32
+
+ 	// Switch on ioctl request.
+ 	switch int(args[1].Int()) {
+ 	case linux.TIOCINQ:
+ 		var rcv tcpip.ReceiveQueueSizeOption
+ 		if netErr := ep.GetSockOpt(&rcv); netErr != nil {
+ 			return 0, syserr.TranslateNetstackError(netErr).ToError()
+ 		}
+ 		if rcv > math.MaxInt32 {
+ 			rcv = math.MaxInt32
+ 		}
+ 		queued = int32(rcv)
+
+ 	case linux.TIOCOUTQ:
+ 		var snd tcpip.SendQueueSizeOption
+ 		if netErr := ep.GetSockOpt(&snd); netErr != nil {
+ 			return 0, syserr.TranslateNetstackError(netErr).ToError()
+ 		}
+ 		if snd > math.MaxInt32 {
+ 			snd = math.MaxInt32
+ 		}
+ 		queued = int32(snd)
+
+ 	default:
+ 		return 0, syserror.ENOTTY
+ 	}
+
+ 	// Copy result to user-space.
+ 	_, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), queued, usermem.IOOpts{
+ 		AddressSpaceActive: true,
+ 	})
+ 	return 0, err
+ }
diff --git a/pkg/sentry/socket/epsocket/provider.go b/pkg/sentry/socket/epsocket/provider.go
new file mode 100644
index 000000000..5616435b3
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/provider.go
@@ -0,0 +1,113 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/udp"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+ // provider is an inet socket provider.
+ type provider struct {
+ // family is the Linux address family this provider is registered for
+ // (AF_INET or AF_INET6 in init).
+ family int
+ // netProto is the netstack network protocol used when creating
+ // endpoints for this family.
+ netProto tcpip.NetworkProtocolNumber
+ }
+
+ // GetTransportProtocol figures out transport protocol. Currently only TCP and
+ // UDP are supported.
+ //
+ // SOCK_STREAM maps to TCP and SOCK_DGRAM to UDP. protocol must either be 0
+ // or name the protocol matching the socket type; every other combination
+ // yields syserr.ErrInvalidArgument.
+ func GetTransportProtocol(stype unix.SockType, protocol int) (tcpip.TransportProtocolNumber, *syserr.Error) {
+ 	switch {
+ 	case stype == linux.SOCK_STREAM && (protocol == 0 || protocol == syscall.IPPROTO_TCP):
+ 		return tcp.ProtocolNumber, nil
+ 	case stype == linux.SOCK_DGRAM && (protocol == 0 || protocol == syscall.IPPROTO_UDP):
+ 		return udp.ProtocolNumber, nil
+ 	}
+ 	return 0, syserr.ErrInvalidArgument
+ }
+
+ // Socket creates a new socket object for the AF_INET or AF_INET6 family.
+ //
+ // It returns (nil, nil) when the task has no network stack or the stack is
+ // not a *Stack, which lets the socket code try another provider. An
+ // unsupported type/protocol combination or a failure to create the endpoint
+ // is returned as an error.
+ func (p *provider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Fail right away if we don't have a stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ // Don't propagate an error here. Instead, allow the socket
+ // code to continue searching for another provider.
+ return nil, nil
+ }
+ // Only stacks of this package's *Stack type are served here.
+ eps, ok := stack.(*Stack)
+ if !ok {
+ return nil, nil
+ }
+
+ // Figure out the transport protocol.
+ transProto, err := GetTransportProtocol(stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+
+ // Create the endpoint.
+ wq := &waiter.Queue{}
+ ep, e := eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+ if e != nil {
+ return nil, syserr.TranslateNetstackError(e)
+ }
+
+ return New(t, p.family, stype, wq, ep), nil
+ }
+
+ // Pair just returns nil sockets (not supported).
+ //
+ // No error is returned either, so the result is three nil values.
+ func (*provider) Pair(*kernel.Task, unix.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ return nil, nil, nil
+ }
+
+ // init registers socket providers for AF_INET and AF_INET6.
+ func init() {
+ 	// Providers backed by netstack, one per address family.
+ 	providers := []provider{
+ 		{family: linux.AF_INET, netProto: ipv4.ProtocolNumber},
+ 		{family: linux.AF_INET6, netProto: ipv6.ProtocolNumber},
+ 	}
+
+ 	for i := range providers {
+ 		// Register a pointer into the slice rather than a copy of the
+ 		// element.
+ 		p := &providers[i]
+ 		socket.RegisterProvider(p.family, p)
+ 	}
+ }
diff --git a/pkg/sentry/socket/epsocket/save_restore.go b/pkg/sentry/socket/epsocket/save_restore.go
new file mode 100644
index 000000000..2613f90de
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/save_restore.go
@@ -0,0 +1,27 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+)
+
+ // afterLoad is invoked by stateify.
+ //
+ // The wrapped netstack stack is not part of the saved state (the field is
+ // tagged state:"manual"), so it is re-attached here from
+ // stack.StackFromEnv. Restoring panics if that value is nil.
+ func (s *Stack) afterLoad() {
+ s.Stack = stack.StackFromEnv // FIXME
+ if s.Stack == nil {
+ panic("can't restore without netstack/tcpip/stack.Stack")
+ }
+ }
diff --git a/pkg/sentry/socket/epsocket/stack.go b/pkg/sentry/socket/epsocket/stack.go
new file mode 100644
index 000000000..ec1d96ccb
--- /dev/null
+++ b/pkg/sentry/socket/epsocket/stack.go
@@ -0,0 +1,132 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package epsocket
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv4"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/network/ipv6"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/stack"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/tcp"
+)
+
+ // Stack implements inet.Stack for netstack/tcpip/stack.Stack.
+ type Stack struct {
+ // Stack is the wrapped netstack stack. It is tagged state:"manual" and
+ // is set again in afterLoad on restore.
+ Stack *stack.Stack `state:"manual"`
+ }
+
+ // SupportsIPv6 implements Stack.SupportsIPv6.
+ //
+ // It returns the result of asking the wrapped stack about the IPv6 network
+ // protocol number.
+ func (s *Stack) SupportsIPv6() bool {
+ return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber)
+ }
+
+ // Interfaces implements inet.Stack.Interfaces.
+ //
+ // The result is a freshly built map from NIC ID to interface description;
+ // only the name and the link address are populated.
+ func (s *Stack) Interfaces() map[int32]inet.Interface {
+ 	result := make(map[int32]inet.Interface)
+ 	for nicID, info := range s.Stack.NICInfo() {
+ 		// TODO: Other fields.
+ 		iface := inet.Interface{
+ 			Name: info.Name,
+ 			Addr: []byte(info.LinkAddress),
+ 		}
+ 		result[int32(nicID)] = iface
+ 	}
+ 	return result
+ }
+
+ // InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+ //
+ // The result maps each NIC ID to its IPv4 and IPv6 addresses. Addresses of
+ // any other network protocol are logged and skipped. PrefixLen is set to
+ // the full bit length of the address.
+ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ 	result := make(map[int32][]inet.InterfaceAddr)
+ 	for nicID, info := range s.Stack.NICInfo() {
+ 		var list []inet.InterfaceAddr
+ 		for _, pa := range info.ProtocolAddresses {
+ 			// TODO: Other fields.
+ 			entry := inet.InterfaceAddr{
+ 				PrefixLen: uint8(len(pa.Address) * 8),
+ 				Addr:      []byte(pa.Address),
+ 			}
+ 			switch pa.Protocol {
+ 			case ipv4.ProtocolNumber:
+ 				entry.Family = linux.AF_INET
+ 			case ipv6.ProtocolNumber:
+ 				entry.Family = linux.AF_INET6
+ 			default:
+ 				log.Warningf("Unknown network protocol in %+v", pa)
+ 				continue
+ 			}
+ 			list = append(list, entry)
+ 		}
+ 		result[int32(nicID)] = list
+ 	}
+ 	return result
+ }
+
+ // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+ //
+ // It reads the TCP receive buffer size option from the wrapped stack and
+ // returns its Min/Default/Max values together with the translated error.
+ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ 	var opt tcp.ReceiveBufferSizeOption
+ 	netErr := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &opt)
+ 	size := inet.TCPBufferSize{
+ 		Min:     opt.Min,
+ 		Default: opt.Default,
+ 		Max:     opt.Max,
+ 	}
+ 	return size, syserr.TranslateNetstackError(netErr).ToError()
+ }
+
+ // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+ //
+ // size's Min/Default/Max values are applied to the wrapped stack as the TCP
+ // receive buffer size option.
+ func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ 	opt := tcp.ReceiveBufferSizeOption{
+ 		Min:     size.Min,
+ 		Default: size.Default,
+ 		Max:     size.Max,
+ 	}
+ 	netErr := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+ 	return syserr.TranslateNetstackError(netErr).ToError()
+ }
+
+ // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+ //
+ // It reads the TCP send buffer size option from the wrapped stack and
+ // returns its Min/Default/Max values together with the translated error.
+ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ 	var opt tcp.SendBufferSizeOption
+ 	netErr := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &opt)
+ 	size := inet.TCPBufferSize{
+ 		Min:     opt.Min,
+ 		Default: opt.Default,
+ 		Max:     opt.Max,
+ 	}
+ 	return size, syserr.TranslateNetstackError(netErr).ToError()
+ }
+
+ // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+ //
+ // size's Min/Default/Max values are applied to the wrapped stack as the TCP
+ // send buffer size option.
+ func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ 	opt := tcp.SendBufferSizeOption{
+ 		Min:     size.Min,
+ 		Default: size.Default,
+ 		Max:     size.Max,
+ 	}
+ 	netErr := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+ 	return syserr.TranslateNetstackError(netErr).ToError()
+ }
+
+ // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+ //
+ // It reads the tcp.SACKEnabled option from the wrapped stack and returns it
+ // as a bool together with the translated error.
+ func (s *Stack) TCPSACKEnabled() (bool, error) {
+ 	var enabled tcp.SACKEnabled
+ 	netErr := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &enabled)
+ 	return bool(enabled), syserr.TranslateNetstackError(netErr).ToError()
+ }
+
+ // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+ //
+ // enabled is applied to the wrapped stack as the tcp.SACKEnabled option.
+ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ 	opt := tcp.SACKEnabled(enabled)
+ 	netErr := s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, opt)
+ 	return syserr.TranslateNetstackError(netErr).ToError()
+ }
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
new file mode 100644
index 000000000..60ec265ba
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -0,0 +1,53 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+ # Generates hostinet_autogen_state.go from the sources listed in srcs; the
+ # generated file is listed in the srcs of the hostinet go_library.
+ go_stateify(
+ name = "hostinet_state",
+ srcs = [
+ "save_restore.go",
+ "socket.go",
+ "stack.go",
+ ],
+ out = "hostinet_autogen_state.go",
+ package = "hostinet",
+ )
+
+ # The hostinet package: hand-written sources plus the stateify-generated
+ # hostinet_autogen_state.go. Visible only to //pkg/sentry:internal.
+ go_library(
+ name = "hostinet",
+ srcs = [
+ "device.go",
+ "hostinet.go",
+ "hostinet_autogen_state.go",
+ "save_restore.go",
+ "socket.go",
+ "socket_unsafe.go",
+ "stack.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/hostinet",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/log",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/kdefs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/safemem",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/usermem",
+ "//pkg/state",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip/transport/unix",
+ "//pkg/waiter",
+ "//pkg/waiter/fdnotifier",
+ ],
+ )
diff --git a/pkg/sentry/socket/hostinet/device.go b/pkg/sentry/socket/hostinet/device.go
new file mode 100644
index 000000000..a9a673316
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/device.go
@@ -0,0 +1,19 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+ // socketDevice is the anonymous device passed to socket.NewDirent for every
+ // socket file created by this package.
+ var socketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/hostinet/hostinet.go b/pkg/sentry/socket/hostinet/hostinet.go
new file mode 100644
index 000000000..67c6c8066
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/hostinet.go
@@ -0,0 +1,17 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package hostinet implements AF_INET and AF_INET6 sockets using the host's
+// network stack.
+package hostinet
diff --git a/pkg/sentry/socket/hostinet/save_restore.go b/pkg/sentry/socket/hostinet/save_restore.go
new file mode 100644
index 000000000..0821a794a
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/save_restore.go
@@ -0,0 +1,20 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+ // beforeSave is invoked by stateify.
+ //
+ // It panics unconditionally: socketOperations wraps a host file descriptor
+ // and is not savable. The message names this package (hostinet) so the panic
+ // can be traced to the right type.
+ func (*socketOperations) beforeSave() {
+ panic("hostinet.socketOperations is not savable")
+ }
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
new file mode 100644
index 000000000..defa3db2c
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -0,0 +1,562 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+ "gvisor.googlesource.com/gvisor/pkg/waiter/fdnotifier"
+)
+
+ const (
+ // sizeofInt32 is the length in bytes used for int-valued socket options
+ // in GetSockOpt and SetSockOpt.
+ sizeofInt32 = 4
+
+ // sizeofSockaddr is the size in bytes of the largest sockaddr type
+ // supported by this package.
+ sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in)
+ )
+
+ // socketOperations implements fs.FileOperations and socket.Socket for a socket
+ // implemented using a host socket.
+ type socketOperations struct {
+ socket.ReceiveTimeout
+ fsutil.PipeSeek `state:"nosave"`
+ fsutil.NotDirReaddir `state:"nosave"`
+ fsutil.NoFsync `state:"nosave"`
+ fsutil.NoopFlush `state:"nosave"`
+ fsutil.NoMMap `state:"nosave"`
+
+ // fd is the host socket file descriptor; Release closes it.
+ fd int // must be O_NONBLOCK
+ // queue holds the waiters registered through EventRegister; it is handed
+ // to fdnotifier.AddFD when the socket file is created.
+ queue waiter.Queue
+ }
+
+ // newSocketFile wraps the host socket fd in a socketOperations, registers it
+ // with fdnotifier and returns a read/write fs.File for it. nonblock only sets
+ // the NonBlocking flag of the returned file. If registration fails the fd is
+ // left open for the caller to deal with.
+ func newSocketFile(ctx context.Context, fd int, nonblock bool) (*fs.File, *syserr.Error) {
+ ops := &socketOperations{fd: fd}
+ if addErr := fdnotifier.AddFD(int32(fd), &ops.queue); addErr != nil {
+ return nil, syserr.FromError(addErr)
+ }
+ d := socket.NewDirent(ctx, socketDevice)
+ flags := fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true}
+ return fs.NewFile(ctx, d, flags, ops), nil
+ }
+
+ // Release implements fs.FileOperations.Release.
+ //
+ // The fd is removed from fdnotifier before it is closed. The result of
+ // close(2) is ignored.
+ func (s *socketOperations) Release() {
+ fdnotifier.RemoveFD(int32(s.fd))
+ syscall.Close(s.fd)
+ }
+
+ // Readiness implements waiter.Waitable.Readiness.
+ //
+ // The answer comes from a non-blocking poll of the host fd.
+ func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return fdnotifier.NonBlockingPoll(int32(s.fd), mask)
+ }
+
+ // EventRegister implements waiter.Waitable.EventRegister.
+ //
+ // After adding the entry to the queue, fdnotifier is told to refresh its
+ // state for the host fd.
+ func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.queue.EventRegister(e, mask)
+ fdnotifier.UpdateFD(int32(s.fd))
+ }
+
+ // EventUnregister implements waiter.Waitable.EventUnregister.
+ //
+ // After removing the entry from the queue, fdnotifier is told to refresh its
+ // state for the host fd.
+ func (s *socketOperations) EventUnregister(e *waiter.Entry) {
+ s.queue.EventUnregister(e)
+ fdnotifier.UpdateFD(int32(s.fd))
+ }
+
+ // Read implements fs.FileOperations.Read.
+ //
+ // The file and offset arguments are ignored. Data is read from the host
+ // socket into the blocks supplied by dst.CopyOutFrom, using read(2) for a
+ // single block and readv for several. EAGAIN/EWOULDBLOCK from the host is
+ // reported as syserror.ErrWouldBlock.
+ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+ n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+ if dsts.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Read(s.fd, dsts.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return readv(s.fd, iovecsFromBlockSeq(dsts))
+ }))
+ return int64(n), err
+ }
+
+ // Write implements fs.FileOperations.Write.
+ //
+ // The file and offset arguments are ignored. Data is written to the host
+ // socket from the blocks supplied by src.CopyInTo, using write(2) for a
+ // single block and writev for several. EAGAIN/EWOULDBLOCK from the host is
+ // reported as syserror.ErrWouldBlock.
+ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := src.CopyInTo(ctx, safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+ if srcs.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ n, err := syscall.Write(s.fd, srcs.Head().ToSlice())
+ if err != nil {
+ return 0, translateIOSyscallError(err)
+ }
+ return uint64(n), nil
+ }
+ return writev(s.fd, iovecsFromBlockSeq(srcs))
+ }))
+ return int64(n), err
+ }
+
+ // Connect implements socket.Socket.Connect.
+ //
+ // The host socket is non-blocking, so a blocking connect is emulated: when the
+ // host returns EINPROGRESS and blocking is set, the task waits until the
+ // socket is writable and the outcome is then read back from SO_ERROR. Any
+ // other errno, or EINPROGRESS with blocking unset, is returned directly.
+ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+ // Cap the address at the largest sockaddr this package supports.
+ if len(sockaddr) > sizeofSockaddr {
+ sockaddr = sockaddr[:sizeofSockaddr]
+ }
+
+ _, _, errno := syscall.Syscall(syscall.SYS_CONNECT, uintptr(s.fd), uintptr(firstBytePtr(sockaddr)), uintptr(len(sockaddr)))
+
+ if errno == 0 {
+ return nil
+ }
+ if errno != syscall.EINPROGRESS || !blocking {
+ return syserr.FromError(translateIOSyscallError(errno))
+ }
+
+ // "EINPROGRESS: The socket is nonblocking and the connection cannot be
+ // completed immediately. It is possible to select(2) or poll(2) for
+ // completion by selecting the socket for writing. After select(2)
+ // indicates writability, use getsockopt(2) to read the SO_ERROR option at
+ // level SOL-SOCKET to determine whether connect() completed successfully
+ // (SO_ERROR is zero) or unsuccessfully (SO_ERROR is one of the usual error
+ // codes listed here, explaining the reason for the failure)." - connect(2)
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ // Register before polling so a completion between the two is not missed;
+ // only block if the socket is not already writable.
+ if s.Readiness(waiter.EventOut)&waiter.EventOut == 0 {
+ if err := t.Block(ch); err != nil {
+ return syserr.FromError(err)
+ }
+ }
+ val, err := syscall.GetsockoptInt(s.fd, syscall.SOL_SOCKET, syscall.SO_ERROR)
+ if err != nil {
+ return syserr.FromError(err)
+ }
+ if val != 0 {
+ return syserr.FromError(syscall.Errno(uintptr(val)))
+ }
+ return nil
+ }
+
+ // Accept implements socket.Socket.Accept.
+ //
+ // The host accept4 is always issued with SOCK_NONBLOCK only. The caller's
+ // flags affect just the sentry-side file (NonBlocking) and descriptor
+ // (CloseOnExec). When blocking is set and the host reports that it would
+ // block, the task waits for EventIn and retries. The peer address and its
+ // length are returned only when peerRequested is set.
+ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ var peerAddr []byte
+ var peerAddrlen uint32
+ var peerAddrPtr *byte
+ var peerAddrlenPtr *uint32
+ if peerRequested {
+ peerAddr = make([]byte, sizeofSockaddr)
+ peerAddrlen = uint32(len(peerAddr))
+ peerAddrPtr = &peerAddr[0]
+ peerAddrlenPtr = &peerAddrlen
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it.
+ fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK)
+ if blocking {
+ var ch chan struct{}
+ for syscallErr == syserror.ErrWouldBlock {
+ // First pass: register the waiter and retry immediately. Later
+ // passes: block on the channel before retrying.
+ if ch != nil {
+ if syscallErr = t.Block(ch); syscallErr != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ fd, syscallErr = accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK)
+ }
+ }
+
+ if peerRequested {
+ peerAddr = peerAddr[:peerAddrlen]
+ }
+ if syscallErr != nil {
+ return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr)
+ }
+
+ f, err := newSocketFile(t, fd, flags&syscall.SOCK_NONBLOCK != 0)
+ if err != nil {
+ // newSocketFile leaves fd open on failure, so close it here.
+ syscall.Close(fd)
+ return 0, nil, 0, err
+ }
+ // Drop this function's reference to f on return.
+ defer f.DecRef()
+
+ fdFlags := kernel.FDFlags{
+ CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0,
+ }
+ kfd, kerr := t.FDMap().NewFDFrom(0, f, fdFlags, t.ThreadGroup().Limits())
+ return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr)
+ }
+
+ // Bind implements socket.Socket.Bind.
+ //
+ // Addresses longer than sizeofSockaddr are truncated to that length before
+ // being handed to the host bind(2).
+ func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ addr := sockaddr
+ if len(addr) > sizeofSockaddr {
+ addr = addr[:sizeofSockaddr]
+ }
+
+ if _, _, errno := syscall.Syscall(syscall.SYS_BIND, uintptr(s.fd), uintptr(firstBytePtr(addr)), uintptr(len(addr))); errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+ }
+
+ // Listen implements socket.Socket.Listen.
+ //
+ // backlog is passed to the host listen(2) unchanged.
+ func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.FromError(syscall.Listen(s.fd, backlog))
+ }
+
+ // Shutdown implements socket.Socket.Shutdown.
+ //
+ // Only SHUT_RD, SHUT_WR and SHUT_RDWR are forwarded to the host; any other
+ // value of how is rejected with ErrInvalidArgument without a host call.
+ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ if how != syscall.SHUT_RD && how != syscall.SHUT_WR && how != syscall.SHUT_RDWR {
+ return syserr.ErrInvalidArgument
+ }
+ return syserr.FromError(syscall.Shutdown(s.fd, how))
+ }
+
+ // GetSockOpt implements socket.Socket.GetSockOpt.
+ //
+ // Only the whitelisted (level, name) pairs below are forwarded to the host.
+ // Anything else yields ErrProtocolNotAvailable. A negative outLen, or one
+ // smaller than the option's fixed size, yields ErrInvalidArgument. On success
+ // the raw option bytes from the host are returned.
+ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+ if outLen < 0 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Whitelist options and constrain option length.
+ var optlen int
+ switch level {
+ case syscall.SOL_IPV6:
+ switch name {
+ case syscall.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case syscall.SOL_SOCKET:
+ switch name {
+ case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR, syscall.SO_TYPE:
+ optlen = sizeofInt32
+ case syscall.SO_LINGER:
+ optlen = syscall.SizeofLinger
+ }
+ case syscall.SOL_TCP:
+ switch name {
+ case syscall.TCP_NODELAY:
+ optlen = sizeofInt32
+ case syscall.TCP_INFO:
+ optlen = int(linux.SizeOfTCPInfo)
+ }
+ }
+ // optlen is still zero when the option was not matched above.
+ if optlen == 0 {
+ return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT
+ }
+ if outLen < optlen {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ opt, err := getsockopt(s.fd, level, name, optlen)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return opt, nil
+ }
+
+ // SetSockOpt implements socket.Socket.SetSockOpt.
+ //
+ // Only the whitelisted (level, name) pairs below are forwarded to the host,
+ // with opt truncated to the option's fixed size. Options outside the
+ // whitelist are silently accepted and ignored. A whitelisted option with too
+ // short a value yields ErrInvalidArgument.
+ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ // Whitelist options and constrain option length.
+ var optlen int
+ switch level {
+ case syscall.SOL_IPV6:
+ switch name {
+ case syscall.IPV6_V6ONLY:
+ optlen = sizeofInt32
+ }
+ case syscall.SOL_SOCKET:
+ switch name {
+ case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR:
+ optlen = sizeofInt32
+ }
+ case syscall.SOL_TCP:
+ switch name {
+ case syscall.TCP_NODELAY:
+ optlen = sizeofInt32
+ }
+ }
+ if optlen == 0 {
+ // Pretend to accept socket options we don't understand. This seems
+ // dangerous, but it's what netstack does...
+ return nil
+ }
+ if len(opt) < optlen {
+ return syserr.ErrInvalidArgument
+ }
+ opt = opt[:optlen]
+
+ _, _, errno := syscall.Syscall6(syscall.SYS_SETSOCKOPT, uintptr(s.fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(len(opt)), 0)
+ if errno != 0 {
+ return syserr.FromError(errno)
+ }
+ return nil
+ }
+
+ // RecvMsg implements socket.Socket.RecvMsg.
+ //
+ // Only MSG_DONTWAIT, MSG_PEEK and MSG_TRUNC are accepted; any other flag
+ // yields ErrInvalidArgument. The host call is always made with MSG_DONTWAIT.
+ // If the caller did not pass MSG_DONTWAIT, blocking is emulated by waiting
+ // for EventIn (honouring haveDeadline/deadline) and retrying. The sender
+ // address is filled in only when senderRequested is set. controlDataLen is
+ // not used and the returned control messages are always empty.
+ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, unix.ControlMessages, *syserr.Error) {
+ // Whitelist flags.
+ //
+ // FIXME: We can't support MSG_ERRQUEUE because it uses ancillary
+ // messages that netstack/tcpip/transport/unix doesn't understand. Kill the
+ // Socket interface's dependence on netstack.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 {
+ return 0, nil, 0, unix.ControlMessages{}, syserr.ErrInvalidArgument
+ }
+
+ var senderAddr []byte
+ if senderRequested {
+ senderAddr = make([]byte, sizeofSockaddr)
+ }
+
+ // recvmsgToBlocks performs one non-blocking host receive into the given
+ // blocks and reslices senderAddr to the address length the host reports.
+ recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of dst.Addrs was unusable.
+ if uint64(dst.NumBytes()) != dsts.NumBytes() {
+ return 0, nil
+ }
+ if dsts.IsEmpty() {
+ return 0, nil
+ }
+
+ // We always do a non-blocking recv*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ if dsts.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddr)
+ }
+
+ iovs := iovecsFromBlockSeq(dsts)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(senderAddr) != 0 {
+ msg.Name = &senderAddr[0]
+ msg.Namelen = uint32(len(senderAddr))
+ }
+ n, err := recvmsg(s.fd, &msg, sysflags)
+ if err != nil {
+ return 0, err
+ }
+ senderAddr = senderAddr[:msg.Namelen]
+ return n, nil
+ })
+
+ var ch chan struct{}
+ n, err := dst.CopyOutFrom(t, recvmsgToBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyOutFrom: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ // First pass: register the waiter and retry immediately. Later
+ // passes: block (with the optional deadline) before retrying.
+ if ch != nil {
+ if err = t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventIn)
+ defer s.EventUnregister(&e)
+ }
+ n, err = dst.CopyOutFrom(t, recvmsgToBlocks)
+ }
+ }
+
+ return int(n), senderAddr, uint32(len(senderAddr)), unix.ControlMessages{}, syserr.FromError(err)
+ }
+
+ // SendMsg implements socket.Socket.SendMsg.
+ //
+ // Only MSG_DONTWAIT, MSG_EOR, MSG_FASTOPEN, MSG_MORE and MSG_NOSIGNAL are
+ // accepted; any other flag yields ErrInvalidArgument. The host call is always
+ // made with MSG_DONTWAIT. If the caller did not pass MSG_DONTWAIT, blocking
+ // is emulated by waiting for EventOut and retrying, with no deadline. to is
+ // the optional destination address. controlMessages is not used.
+ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
+ // Whitelist flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ // sendmsgFromBlocks performs one non-blocking host send of the given
+ // blocks.
+ sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) {
+ // Refuse to do anything if any part of src.Addrs was unusable.
+ if uint64(src.NumBytes()) != srcs.NumBytes() {
+ return 0, nil
+ }
+ if srcs.IsEmpty() {
+ return 0, nil
+ }
+
+ // We always do a non-blocking send*().
+ sysflags := flags | syscall.MSG_DONTWAIT
+
+ if srcs.NumBlocks() == 1 {
+ // Skip allocating []syscall.Iovec.
+ // Note: this local src shadows the outer IOSequence parameter.
+ src := srcs.Head()
+ n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
+
+ iovs := iovecsFromBlockSeq(srcs)
+ msg := syscall.Msghdr{
+ Iov: &iovs[0],
+ Iovlen: uint64(len(iovs)),
+ }
+ if len(to) != 0 {
+ msg.Name = &to[0]
+ msg.Namelen = uint32(len(to))
+ }
+ return sendmsg(s.fd, &msg, sysflags)
+ })
+
+ var ch chan struct{}
+ n, err := src.CopyInTo(t, sendmsgFromBlocks)
+ if flags&syscall.MSG_DONTWAIT == 0 {
+ for err == syserror.ErrWouldBlock {
+ // We only expect blocking to come from the actual syscall, in which
+ // case it can't have returned any data.
+ if n != 0 {
+ panic(fmt.Sprintf("CopyInTo: got (%d, %v), wanted (0, %v)", n, err, err))
+ }
+ // First pass: register the waiter and retry immediately. Later
+ // passes: block before retrying.
+ if ch != nil {
+ if err = t.Block(ch); err != nil {
+ break
+ }
+ } else {
+ var e waiter.Entry
+ e, ch = waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+ }
+ n, err = src.CopyInTo(t, sendmsgFromBlocks)
+ }
+ }
+
+ return int(n), syserr.FromError(err)
+ }
+
+ // iovecsFromBlockSeq returns one syscall.Iovec per block in bs, pointing at
+ // the first byte of each block.
+ //
+ // NOTE(review): &b.ToSlice()[0] would panic on a zero-length block; this
+ // presumably relies on BlockSeq never yielding empty blocks. Confirm against
+ // safemem.
+ func iovecsFromBlockSeq(bs safemem.BlockSeq) []syscall.Iovec {
+ iovs := make([]syscall.Iovec, 0, bs.NumBlocks())
+ for ; !bs.IsEmpty(); bs = bs.Tail() {
+ b := bs.Head()
+ iovs = append(iovs, syscall.Iovec{
+ Base: &b.ToSlice()[0],
+ Len: uint64(b.Len()),
+ })
+ // We don't need to care about b.NeedSafecopy(), because the host
+ // kernel will handle such address ranges just fine (by returning
+ // EFAULT).
+ }
+ return iovs
+ }
+
+ // translateIOSyscallError maps the host's EAGAIN/EWOULDBLOCK onto
+ // syserror.ErrWouldBlock and passes every other error through unchanged.
+ func translateIOSyscallError(err error) error {
+ wouldBlock := err == syscall.EAGAIN || err == syscall.EWOULDBLOCK
+ if !wouldBlock {
+ return err
+ }
+ return syserror.ErrWouldBlock
+ }
+
+ // socketProvider creates host-backed sockets for a single address family.
+ type socketProvider struct {
+ // family is the address family passed to the host socket(2) call.
+ family int
+ }
+
+ // Socket implements socket.Provider.Socket.
+ //
+ // It returns (nil, nil) to decline the request when the task is not using the
+ // host network stack, or when the type/protocol pair is anything other than
+ // SOCK_STREAM with TCP or SOCK_DGRAM with UDP (protocol 0 is accepted for
+ // both). Otherwise a non-blocking host socket is created and wrapped in an
+ // fs.File; the application's SOCK_NONBLOCK flag only affects that file.
+ func (p *socketProvider) Socket(t *kernel.Task, stypeflags unix.SockType, protocol int) (*fs.File, *syserr.Error) {
+ // Check that we are using the host network stack.
+ stack := t.NetworkContext()
+ if stack == nil {
+ return nil, nil
+ }
+ if _, ok := stack.(*Stack); !ok {
+ return nil, nil
+ }
+
+ // Only accept TCP and UDP.
+ stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+ switch stype {
+ case syscall.SOCK_STREAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_TCP:
+ // ok
+ default:
+ return nil, nil
+ }
+ case syscall.SOCK_DGRAM:
+ switch protocol {
+ case 0, syscall.IPPROTO_UDP:
+ // ok
+ default:
+ return nil, nil
+ }
+ default:
+ return nil, nil
+ }
+
+ // Conservatively ignore all flags specified by the application and add
+ // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0
+ // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent.
+ fd, err := syscall.Socket(p.family, stype|syscall.SOCK_NONBLOCK, 0)
+ if err != nil {
+ return nil, syserr.FromError(err)
+ }
+ f, serr := newSocketFile(t, fd, stypeflags&syscall.SOCK_NONBLOCK != 0)
+ if serr != nil {
+ // newSocketFile leaves fd open on failure; close it here so the host
+ // descriptor is not leaked (Accept does the same).
+ syscall.Close(fd)
+ return nil, serr
+ }
+ return f, nil
+ }
+
+ // Pair implements socket.Provider.Pair.
+ //
+ // It always returns (nil, nil, nil); all arguments are ignored.
+ func (p *socketProvider) Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+ }
+
+ // init registers a host-backed socket provider for AF_INET and for AF_INET6.
+ func init() {
+ socket.RegisterProvider(syscall.AF_INET, &socketProvider{family: syscall.AF_INET})
+ socket.RegisterProvider(syscall.AF_INET6, &socketProvider{family: syscall.AF_INET6})
+ }
diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go
new file mode 100644
index 000000000..f8bb75636
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/socket_unsafe.go
@@ -0,0 +1,138 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+ // firstBytePtr returns a pointer to the first byte of bs, or nil when bs has
+ // no elements.
+ //
+ // The check is on length rather than nil-ness: an empty but non-nil slice
+ // (e.g. make([]byte, 0)) has no element 0, so indexing it would panic.
+ func firstBytePtr(bs []byte) unsafe.Pointer {
+ if len(bs) == 0 {
+ return nil
+ }
+ return unsafe.Pointer(&bs[0])
+ }
+
+ // readv issues the host readv(2) on fd with the given iovecs and returns the
+ // number of bytes read. EAGAIN/EWOULDBLOCK is reported as
+ // syserror.ErrWouldBlock.
+ //
+ // Preconditions: len(dsts) != 0.
+ func readv(fd int, dsts []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&dsts[0])), uintptr(len(dsts)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
+
+ // writev issues the host writev(2) on fd with the given iovecs and returns
+ // the number of bytes written. EAGAIN/EWOULDBLOCK is reported as
+ // syserror.ErrWouldBlock.
+ //
+ // Preconditions: len(srcs) != 0.
+ func writev(fd int, srcs []syscall.Iovec) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_WRITEV, uintptr(fd), uintptr(unsafe.Pointer(&srcs[0])), uintptr(len(srcs)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
+
+ // Ioctl implements fs.FileOperations.Ioctl.
+ //
+ // Only TIOCINQ and TIOCOUTQ are supported: the request is executed on the
+ // host fd and the 32-bit result is copied out to the address in args[2]. Any
+ // other command returns ENOTTY.
+ func (s *socketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ switch cmd := uintptr(args[1].Int()); cmd {
+ case syscall.TIOCINQ, syscall.TIOCOUTQ:
+ var val int32
+ if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(s.fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ // Serialize the result in usermem.ByteOrder before copying it out.
+ var buf [4]byte
+ usermem.ByteOrder.PutUint32(buf[:], uint32(val))
+ _, err := io.CopyOut(ctx, args[2].Pointer(), buf[:], usermem.IOOpts{
+ AddressSpaceActive: true,
+ })
+ return 0, err
+
+ default:
+ return 0, syserror.ENOTTY
+ }
+ }
+
+ // accept4 issues the host accept4(2) on fd. addr and addrlen are passed
+ // straight through and may be nil. On failure it returns fd 0 and the errno,
+ // with EAGAIN/EWOULDBLOCK reported as syserror.ErrWouldBlock.
+ func accept4(fd int, addr *byte, addrlen *uint32, flags int) (int, error) {
+ afd, _, errno := syscall.Syscall6(syscall.SYS_ACCEPT4, uintptr(fd), uintptr(unsafe.Pointer(addr)), uintptr(unsafe.Pointer(addrlen)), uintptr(flags), 0, 0)
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return int(afd), nil
+ }
+
+ // getsockopt issues the host getsockopt(2) with a freshly allocated buffer of
+ // optlen bytes and returns that buffer resliced to the length the host wrote
+ // back. On failure it returns the raw errno.
+ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) {
+ opt := make([]byte, optlen)
+ optlen32 := int32(len(opt))
+ _, _, errno := syscall.Syscall6(syscall.SYS_GETSOCKOPT, uintptr(fd), uintptr(level), uintptr(name), uintptr(firstBytePtr(opt)), uintptr(unsafe.Pointer(&optlen32)), 0)
+ if errno != 0 {
+ return nil, errno
+ }
+ return opt[:optlen32], nil
+ }
+
+ // GetSockName implements socket.Socket.GetSockName.
+ //
+ // The address is fetched from the host with getsockname(2) into a
+ // sizeofSockaddr-byte buffer and returned as raw bytes, resliced to the
+ // length the host reports.
+ func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return addr[:addrlen], addrlen, nil
+ }
+
+ // GetPeerName implements socket.Socket.GetPeerName.
+ //
+ // The address is fetched from the host with getpeername(2) into a
+ // sizeofSockaddr-byte buffer and returned as raw bytes, resliced to the
+ // length the host reports.
+ func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ addr := make([]byte, sizeofSockaddr)
+ addrlen := uint32(len(addr))
+ _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen)))
+ if errno != 0 {
+ return nil, 0, syserr.FromError(errno)
+ }
+ return addr[:addrlen], addrlen, nil
+ }
+
+ // recvfrom issues the host recvfrom(2) on fd into dst. *from is the buffer
+ // for the sender address; on success it is resliced to the address length
+ // the host reports, and on failure it is left untouched. EAGAIN/EWOULDBLOCK
+ // is reported as syserror.ErrWouldBlock.
+ func recvfrom(fd int, dst []byte, flags int, from *[]byte) (uint64, error) {
+ fromLen := uint32(len(*from))
+ n, _, errno := syscall.Syscall6(syscall.SYS_RECVFROM, uintptr(fd), uintptr(firstBytePtr(dst)), uintptr(len(dst)), uintptr(flags), uintptr(firstBytePtr(*from)), uintptr(unsafe.Pointer(&fromLen)))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ *from = (*from)[:fromLen]
+ return uint64(n), nil
+ }
+
+ // recvmsg issues the host recvmsg(2) on fd and returns the byte count.
+ // EAGAIN/EWOULDBLOCK is reported as syserror.ErrWouldBlock.
+ func recvmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_RECVMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
+
+ // sendmsg issues the host sendmsg(2) on fd and returns the byte count.
+ // EAGAIN/EWOULDBLOCK is reported as syserror.ErrWouldBlock.
+ func sendmsg(fd int, msg *syscall.Msghdr, flags int) (uint64, error) {
+ n, _, errno := syscall.Syscall(syscall.SYS_SENDMSG, uintptr(fd), uintptr(unsafe.Pointer(msg)), uintptr(flags))
+ if errno != 0 {
+ return 0, translateIOSyscallError(errno)
+ }
+ return uint64(n), nil
+ }
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
new file mode 100644
index 000000000..44c3b9a3f
--- /dev/null
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -0,0 +1,244 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package hostinet
+
+import (
+ "fmt"
+ "io/ioutil"
+ "os"
+ "strings"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/log"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+)
+
+ // defaultRecvBufSize is the TCP receive buffer size that Configure falls back
+ // to when /proc/sys/net/ipv4/tcp_rmem cannot be read or parsed.
+ var defaultRecvBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 87380,
+ Max: 6291456,
+ }
+
+ // defaultSendBufSize is the TCP send buffer size that Configure falls back to
+ // when /proc/sys/net/ipv4/tcp_wmem cannot be read or parsed.
+ var defaultSendBufSize = inet.TCPBufferSize{
+ Min: 4096,
+ Default: 16384,
+ Max: 4194304,
+ }
+
+ // Stack implements inet.Stack for host sockets.
+ //
+ // All fields are populated by Configure; the Set* methods never modify them.
+ type Stack struct {
+ // Stack is immutable.
+ // interfaces and interfaceAddrs are keyed by interface index.
+ interfaces map[int32]inet.Interface
+ interfaceAddrs map[int32][]inet.InterfaceAddr
+ // supportsIPv6 is set when /proc/net/if_inet6 exists on the host.
+ supportsIPv6 bool
+ tcpRecvBufSize inet.TCPBufferSize
+ tcpSendBufSize inet.TCPBufferSize
+ tcpSACKEnabled bool
+ }
+
+ // NewStack returns an empty Stack containing no configuration: both
+ // interface maps are allocated but empty and every other field is zero.
+ func NewStack() *Stack {
+ st := new(Stack)
+ st.interfaces = make(map[int32]inet.Interface)
+ st.interfaceAddrs = make(map[int32][]inet.InterfaceAddr)
+ return st
+ }
+
+ // Configure sets up the stack using the current state of the host network.
+ //
+ // Interface enumeration failures are returned to the caller. Everything else
+ // is best-effort: if a /proc file cannot be read or parsed, the built-in
+ // default is kept and a warning that includes the underlying error is logged.
+ func (s *Stack) Configure() error {
+ if err := addHostInterfaces(s); err != nil {
+ return err
+ }
+
+ // IPv6 support is inferred from the presence of /proc/net/if_inet6.
+ if _, err := os.Stat("/proc/net/if_inet6"); err == nil {
+ s.supportsIPv6 = true
+ }
+
+ s.tcpRecvBufSize = defaultRecvBufSize
+ if tcpRMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_rmem"); err == nil {
+ s.tcpRecvBufSize = tcpRMem
+ } else {
+ log.Warningf("Failed to read TCP receive buffer size, using default values: %v", err)
+ }
+
+ s.tcpSendBufSize = defaultSendBufSize
+ if tcpWMem, err := readTCPBufferSizeFile("/proc/sys/net/ipv4/tcp_wmem"); err == nil {
+ s.tcpSendBufSize = tcpWMem
+ } else {
+ log.Warningf("Failed to read TCP send buffer size, using default values: %v", err)
+ }
+
+ // Any content other than "0" (after trimming whitespace) counts as enabled.
+ s.tcpSACKEnabled = false
+ if sack, err := ioutil.ReadFile("/proc/sys/net/ipv4/tcp_sack"); err == nil {
+ s.tcpSACKEnabled = strings.TrimSpace(string(sack)) != "0"
+ } else {
+ log.Warningf("Failed to read whether TCP SACK is enabled, setting to false: %v", err)
+ }
+
+ return nil
+ }
+
+ // ExtractHostInterfaces will populate an interface map and
+ // interfaceAddrs map with the results of the equivalent
+ // netlink messages.
+ //
+ // links is scanned for RTM_NEWLINK messages and addrs for RTM_NEWADDR
+ // messages; other message types are skipped. Entries are keyed by interface
+ // index. An error is returned when a message is shorter than its fixed
+ // header or its route attributes cannot be parsed; entries added before the
+ // failure remain in the maps.
+ func ExtractHostInterfaces(links []syscall.NetlinkMessage, addrs []syscall.NetlinkMessage, interfaces map[int32]inet.Interface, interfaceAddrs map[int32][]inet.InterfaceAddr) error {
+ for _, link := range links {
+ if link.Header.Type != syscall.RTM_NEWLINK {
+ continue
+ }
+ if len(link.Data) < syscall.SizeofIfInfomsg {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid data length (%d bytes, expected at least %d bytes)", len(link.Data), syscall.SizeofIfInfomsg)
+ }
+ var ifinfo syscall.IfInfomsg
+ binary.Unmarshal(link.Data[:syscall.SizeofIfInfomsg], usermem.ByteOrder, &ifinfo)
+ inetIF := inet.Interface{
+ DeviceType: ifinfo.Type,
+ Flags: ifinfo.Flags,
+ }
+ // Not clearly documented: syscall.ParseNetlinkRouteAttr will check the
+ // syscall.NetlinkMessage.Header.Type and skip the struct ifinfomsg
+ // accordingly.
+ attrs, err := syscall.ParseNetlinkRouteAttr(&link)
+ if err != nil {
+ return fmt.Errorf("RTM_GETLINK returned RTM_NEWLINK message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFLA_ADDRESS:
+ inetIF.Addr = attr.Value
+ case syscall.IFLA_IFNAME:
+ // The last byte of the payload is dropped (expected to be the
+ // NUL terminator). An empty payload is left as an empty name
+ // rather than slicing out of range.
+ if n := len(attr.Value); n > 0 {
+ inetIF.Name = string(attr.Value[:n-1])
+ }
+ }
+ }
+ interfaces[ifinfo.Index] = inetIF
+ }
+
+ for _, addr := range addrs {
+ if addr.Header.Type != syscall.RTM_NEWADDR {
+ continue
+ }
+ if len(addr.Data) < syscall.SizeofIfAddrmsg {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid data length (%d bytes, expected at least %d bytes)", len(addr.Data), syscall.SizeofIfAddrmsg)
+ }
+ var ifaddr syscall.IfAddrmsg
+ binary.Unmarshal(addr.Data[:syscall.SizeofIfAddrmsg], usermem.ByteOrder, &ifaddr)
+ inetAddr := inet.InterfaceAddr{
+ Family: ifaddr.Family,
+ PrefixLen: ifaddr.Prefixlen,
+ Flags: ifaddr.Flags,
+ }
+ attrs, err := syscall.ParseNetlinkRouteAttr(&addr)
+ if err != nil {
+ return fmt.Errorf("RTM_GETADDR returned RTM_NEWADDR message with invalid rtattrs: %v", err)
+ }
+ for _, attr := range attrs {
+ switch attr.Attr.Type {
+ case syscall.IFA_ADDRESS:
+ inetAddr.Addr = attr.Value
+ }
+ }
+ // An interface may carry several addresses; collect them per index.
+ interfaceAddrs[int32(ifaddr.Index)] = append(interfaceAddrs[int32(ifaddr.Index)], inetAddr)
+ }
+
+ return nil
+ }
+
+ // addHostInterfaces dumps the host's links and addresses over netlink and
+ // records them in s.interfaces and s.interfaceAddrs.
+ func addHostInterfaces(s *Stack) error {
+ linkMsgs, linkErr := doNetlinkRouteRequest(syscall.RTM_GETLINK)
+ if linkErr != nil {
+ return fmt.Errorf("RTM_GETLINK failed: %v", linkErr)
+ }
+
+ addrMsgs, addrErr := doNetlinkRouteRequest(syscall.RTM_GETADDR)
+ if addrErr != nil {
+ return fmt.Errorf("RTM_GETADDR failed: %v", addrErr)
+ }
+
+ return ExtractHostInterfaces(linkMsgs, addrMsgs, s.interfaces, s.interfaceAddrs)
+ }
+
+ // doNetlinkRouteRequest performs the given NETLINK_ROUTE dump request for
+ // AF_UNSPEC and returns the reply split into individual netlink messages.
+ func doNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
+ rib, ribErr := syscall.NetlinkRIB(req, syscall.AF_UNSPEC)
+ if ribErr == nil {
+ return syscall.ParseNetlinkMessage(rib)
+ }
+ return nil, ribErr
+ }
+
+ // readTCPBufferSizeFile reads filename and parses its contents as three
+ // integers, returned as Min, Default and Max in that order. It fails if the
+ // file cannot be read, if parsing reports an error, or if parsing does not
+ // consume every byte of the file.
+ func readTCPBufferSizeFile(filename string) (inet.TCPBufferSize, error) {
+ contents, err := ioutil.ReadFile(filename)
+ if err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to read %s: %v", filename, err)
+ }
+ // Wrap the bytes in an IOSequence so the usermem int32-string parser can
+ // be reused on host file contents.
+ ioseq := usermem.BytesIOSequence(contents)
+ fields := make([]int32, 3)
+ if n, err := usermem.CopyInt32StringsInVec(context.Background(), ioseq.IO, ioseq.Addrs, fields, ioseq.Opts); n != ioseq.NumBytes() || err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to parse %s (%q): got %v after %d/%d bytes", filename, contents, err, n, ioseq.NumBytes())
+ }
+ return inet.TCPBufferSize{
+ Min: int(fields[0]),
+ Default: int(fields[1]),
+ Max: int(fields[2]),
+ }, nil
+ }
+
+ // Interfaces implements inet.Stack.Interfaces.
+ //
+ // The internal map is returned directly, not a copy.
+ func (s *Stack) Interfaces() map[int32]inet.Interface {
+ return s.interfaces
+ }
+
+ // InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+ //
+ // The internal map is returned directly, not a copy.
+ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ return s.interfaceAddrs
+ }
+
+ // SupportsIPv6 implements inet.Stack.SupportsIPv6.
+ //
+ // It reports the value recorded by Configure.
+ func (s *Stack) SupportsIPv6() bool {
+ return s.supportsIPv6
+ }
+
+ // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+ //
+ // It returns the value recorded by Configure; the error is always nil.
+ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpRecvBufSize, nil
+ }
+
+ // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+ //
+ // It always fails with EACCES; size is ignored.
+ func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+ }
+
+ // TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+ //
+ // It returns the value recorded by Configure; the error is always nil.
+ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpSendBufSize, nil
+ }
+
+ // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+ //
+ // It always fails with EACCES; size is ignored.
+ func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ return syserror.EACCES
+ }
+
+ // TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+ //
+ // It returns the value recorded by Configure; the error is always nil.
+ func (s *Stack) TCPSACKEnabled() (bool, error) {
+ return s.tcpSACKEnabled, nil
+ }
+
+ // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+ //
+ // It always fails with EACCES; enabled is ignored.
+ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ return syserror.EACCES
+ }
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
new file mode 100644
index 000000000..9df3ab17c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -0,0 +1,47 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Generates netlink_state.go for package netlink from socket.go.
+go_stateify(
+ name = "netlink_state",
+ srcs = [
+ "socket.go",
+ ],
+ out = "netlink_state.go",
+ package = "netlink",
+)
+
+# Library for the netlink package. Its sources include the generated
+# netlink_state.go alongside the hand-written files.
+go_library(
+ name = "netlink",
+ srcs = [
+ "message.go",
+ "netlink_state.go",
+ "provider.go",
+ "socket.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/kdefs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/netlink/port",
+ "//pkg/sentry/socket/unix",
+ "//pkg/sentry/usermem",
+ "//pkg/state",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/transport/unix",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
new file mode 100644
index 000000000..b902d7ec9
--- /dev/null
+++ b/pkg/sentry/socket/netlink/message.go
@@ -0,0 +1,159 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+ "math"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+)
+
+// alignUp returns length rounded up to the next multiple of align.
+//
+// Preconditions: align is a power of two.
+func alignUp(length int, align uint) int {
+	// For a power-of-two alignment, align-1 is a mask of the low bits that
+	// must be zero in an aligned value.
+	mask := int(align) - 1
+	return (length + mask) &^ mask
+}
+
+// Message contains a complete serialized netlink message.
+type Message struct {
+ // buf holds the serialized bytes: the marshalled header written by
+ // NewMessage followed by everything appended through Put, PutAttr and
+ // PutAttrString.
+ buf []byte
+}
+
+// NewMessage returns a Message whose buffer starts with the serialized hdr.
+//
+// The length field of the header is filled in later by Finalize.
+func NewMessage(hdr linux.NetlinkMessageHeader) *Message {
+	m := new(Message)
+	m.buf = binary.Marshal(nil, usermem.ByteOrder, hdr)
+	return m
+}
+
+// Finalize writes the total message length into the header and returns the
+// []byte holding the entire message, padded to NLMSG_ALIGNTO. The Message
+// must not be modified after calling Finalize.
+func (m *Message) Finalize() []byte {
+	// The length recorded in the header (its first 4 bytes) is the useful
+	// length of the message, not the padded length. See
+	// net/netlink/af_netlink.c:__nlmsg_put.
+	used := len(m.buf)
+	usermem.ByteOrder.PutUint32(m.buf, uint32(used))
+
+	// Pad the buffer out to the message alignment.
+	m.putZeros(alignUp(used, linux.NLMSG_ALIGNTO) - used)
+	return m.buf
+}
+
+// putZeros appends n zero bytes to the message. It does nothing when n is
+// zero or negative.
+func (m *Message) putZeros(n int) {
+	if n > 0 {
+		m.buf = append(m.buf, make([]byte, n)...)
+	}
+}
+
+// Put serializes v into the message.
+//
+// v is marshalled with usermem.ByteOrder and appended to the end of the
+// buffer; no alignment padding is added.
+func (m *Message) Put(v interface{}) {
+ m.buf = binary.Marshal(m.buf, usermem.ByteOrder, v)
+}
+
+// PutAttr appends v to the message as a netlink attribute of type atype: an
+// attribute header, the serialized value, then zero padding up to
+// NLA_ALIGNTO.
+//
+// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
+// binary.Size(v)) fits in math.MaxUint16 bytes. PutAttr panics otherwise.
+func (m *Message) PutAttr(atype uint16, v interface{}) {
+	size := linux.NetlinkAttrHeaderSize + int(binary.Size(v))
+	if size > math.MaxUint16 {
+		panic(fmt.Sprintf("attribute too large: %d", size))
+	}
+
+	hdr := linux.NetlinkAttrHeader{
+		Type:   atype,
+		Length: uint16(size),
+	}
+	m.Put(hdr)
+	m.Put(v)
+
+	// The header length excludes the padding that aligns the attribute.
+	m.putZeros(alignUp(size, linux.NLA_ALIGNTO) - size)
+}
+
+// PutAttrString adds s to the message as a NUL-terminated netlink attribute
+// of type atype, followed by zero padding up to NLA_ALIGNTO.
+//
+// Preconditions: The serialized attribute (linux.NetlinkAttrHeaderSize +
+// len(s) + 1) fits in math.MaxUint16 bytes. PutAttrString panics otherwise,
+// matching PutAttr, rather than writing a truncated length into the header.
+func (m *Message) PutAttrString(atype uint16, s string) {
+	// Header + string bytes + terminating NUL.
+	l := linux.NetlinkAttrHeaderSize + len(s) + 1
+	if l > math.MaxUint16 {
+		panic(fmt.Sprintf("attribute too large: %d", l))
+	}
+
+	m.Put(linux.NetlinkAttrHeader{
+		Type:   atype,
+		Length: uint16(l),
+	})
+
+	// String + NUL-termination.
+	m.Put([]byte(s))
+	m.putZeros(1)
+
+	// Align the attribute. The padding is not counted in the header length.
+	aligned := alignUp(l, linux.NLA_ALIGNTO)
+	m.putZeros(aligned - l)
+}
+
+// MessageSet contains a series of netlink messages.
+type MessageSet struct {
+ // Multi indicates that this is a multi-part message, to be terminated
+ // by NLMSG_DONE. NLMSG_DONE is sent even if the set contains only one
+ // Message.
+ //
+ // If Multi is set, all added messages will have NLM_F_MULTI set.
+ Multi bool
+
+ // PortID is the destination port for all messages.
+ PortID int32
+
+ // Seq is the sequence counter for all messages in the set.
+ Seq uint32
+
+ // Messages contains the messages in the set, in the order they were
+ // added by AddMessage.
+ Messages []*Message
+}
+
+// NewMessageSet returns an empty MessageSet with Multi unset.
+//
+// portID is the destination port recorded as PortID for every message added
+// to the set.
+//
+// seq is the sequence counter recorded as Seq for every message added to the
+// set.
+func NewMessageSet(portID int32, seq uint32) *MessageSet {
+	ms := new(MessageSet)
+	ms.PortID = portID
+	ms.Seq = seq
+	return ms
+}
+
+// AddMessage creates a message from hdr, appends it to the set and returns it
+// so the caller can add the payload.
+//
+// Seq and PortID in hdr are overwritten with the set's values, and
+// NLM_F_MULTI is added to the flags when the set is multi-part.
+func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message {
+	// hdr is a copy, so the caller's header is left untouched.
+	if ms.Multi {
+		hdr.Flags |= linux.NLM_F_MULTI
+	}
+	hdr.PortID = uint32(ms.PortID)
+	hdr.Seq = ms.Seq
+
+	msg := NewMessage(hdr)
+	ms.Messages = append(ms.Messages, msg)
+	return msg
+}
diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD
new file mode 100644
index 000000000..7340b95c9
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/BUILD
@@ -0,0 +1,28 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library", "go_test")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Generates port_state.go for package port from port.go.
+go_stateify(
+ name = "port_state",
+ srcs = ["port.go"],
+ out = "port_state.go",
+ package = "port",
+)
+
+# Library for the port package, including the generated port_state.go.
+go_library(
+ name = "port",
+ srcs = [
+ "port.go",
+ "port_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink/port",
+ visibility = ["//pkg/sentry:internal"],
+ deps = ["//pkg/state"],
+)
+
+# Tests are embedded in the library so they can reach unexported names.
+go_test(
+ name = "port_test",
+ srcs = ["port_test.go"],
+ embed = [":port"],
+)
diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go
new file mode 100644
index 000000000..4ccf0b84c
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port.go
@@ -0,0 +1,114 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package port provides port ID allocation for netlink sockets.
+//
+// A netlink port is any int32 value. Positive ports are typically equivalent
+// to the PID of the binding process. If that port is unavailable, negative
+// ports are searched to find a free port that will not conflict with other
+// PIDs.
+package port
+
+import (
+ "fmt"
+ "math"
+ "math/rand"
+ "sync"
+)
+
+// maxPorts is a sanity limit on the maximum number of ports to allocate per
+// protocol. The count includes port 0, which Allocate reserves for the kernel
+// the first time a protocol is seen.
+const maxPorts = 10000
+
+// Manager allocates netlink port IDs.
+type Manager struct {
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // ports contains a map of allocated ports for each protocol. The outer
+ // key is the protocol number; the inner map is the set of port IDs
+ // currently allocated for that protocol.
+ ports map[int]map[int32]struct{}
+}
+
+// New returns a Manager with no ports allocated for any protocol.
+func New() *Manager {
+	m := new(Manager)
+	m.ports = make(map[int]map[int32]struct{})
+	return m
+}
+
+// Allocate reserves a new port ID for protocol. hint will be taken if
+// available.
+func (m *Manager) Allocate(protocol int, hint int32) (int32, bool) {
+ m.mu.Lock()
+ defer m.mu.Unlock()
+
+ proto, ok := m.ports[protocol]
+ if !ok {
+ proto = make(map[int32]struct{})
+ // Port 0 is reserved for the kernel.
+ proto[0] = struct{}{}
+ m.ports[protocol] = proto
+ }
+
+ if len(proto) >= maxPorts {
+ return 0, false
+ }
+
+ if _, ok := proto[hint]; !ok {
+ // Hint is available, reserve it.
+ proto[hint] = struct{}{}
+ return hint, true
+ }
+
+ // Search for any free port in [math.MinInt32, -4096). The positive
+ // port space is left open for pid-based allocations. This behavior is
+ // consistent with Linux.
+ start := int32(math.MinInt32 + rand.Int63n(math.MaxInt32-4096+1))
+ curr := start
+ for {
+ if _, ok := proto[curr]; !ok {
+ proto[curr] = struct{}{}
+ return curr, true
+ }
+
+ curr--
+ if curr >= -4096 {
+ curr = -4097
+ }
+ if curr == start {
+ // Nothing found. We should always find a free port
+ // because maxPorts < -4096 - MinInt32.
+ panic(fmt.Sprintf("No free port found in %+v", proto))
+ }
+ }
+}
+
+// Release frees the specified port for protocol so it can be allocated again.
+//
+// Preconditions: port is already allocated. Release panics if protocol has
+// never had an allocation or if port is not currently allocated for it.
+func (m *Manager) Release(protocol int, port int32) {
+	m.mu.Lock()
+	defer m.mu.Unlock()
+
+	allocated, seen := m.ports[protocol]
+	if !seen {
+		panic(fmt.Sprintf("Released port %d for protocol %d which has no allocations", port, protocol))
+	}
+
+	_, held := allocated[port]
+	if !held {
+		panic(fmt.Sprintf("Released port %d for protocol %d is not allocated", port, protocol))
+	}
+
+	delete(allocated, port)
+}
diff --git a/pkg/sentry/socket/netlink/port/port_test.go b/pkg/sentry/socket/netlink/port/port_test.go
new file mode 100644
index 000000000..34565e2f9
--- /dev/null
+++ b/pkg/sentry/socket/netlink/port/port_test.go
@@ -0,0 +1,82 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package port
+
+import (
+ "testing"
+)
+
+// TestAllocateHint checks that Allocate honours the hint when it is free,
+// picks a different port when the hint is taken, tracks ports separately per
+// protocol, and makes a port available again after Release.
+func TestAllocateHint(t *testing.T) {
+ m := New()
+
+ // We can get the hint port.
+ p, ok := m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(0, 1) got %d want 1", p)
+ }
+
+ // Hint is taken.
+ p, ok = m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p == 1 {
+ t.Errorf("m.Allocate(0, 1) got 1 want anything else")
+ }
+
+ // Hint is available for a different protocol.
+ p, ok = m.Allocate(1, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(1, 1) got %d want 1", p)
+ }
+
+ m.Release(0, 1)
+
+ // Hint is available again after release.
+ p, ok = m.Allocate(0, 1)
+ if !ok {
+ t.Errorf("m.Allocate got !ok want ok")
+ }
+ if p != 1 {
+ t.Errorf("m.Allocate(0, 1) got %d want 1", p)
+ }
+}
+
+// TestAllocateExhausted checks that Allocate refuses further allocations once
+// a protocol holds maxPorts entries.
+func TestAllocateExhausted(t *testing.T) {
+ m := New()
+
+ // Fill all ports (0 is already reserved). Every hint in [1, maxPorts) is
+ // free, so each allocation must return the hint itself.
+ for i := int32(1); i < maxPorts; i++ {
+ p, ok := m.Allocate(0, i)
+ if !ok {
+ t.Fatalf("m.Allocate got !ok want ok")
+ }
+ if p != i {
+ t.Fatalf("m.Allocate(0, %d) got %d want %d", i, p, i)
+ }
+ }
+
+ // Now no more can be allocated.
+ p, ok := m.Allocate(0, 1)
+ if ok {
+ t.Errorf("m.Allocate got %d, ok want !ok", p)
+ }
+}
diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go
new file mode 100644
index 000000000..36800da4d
--- /dev/null
+++ b/pkg/sentry/socket/netlink/provider.go
@@ -0,0 +1,104 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netlink
+
+import (
+ "fmt"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+)
+
+// Protocol is the implementation of a netlink socket protocol.
+type Protocol interface {
+ // Protocol returns the Linux netlink protocol value. Socket uses this
+ // value as the protocol key when allocating and releasing port IDs.
+ Protocol() int
+
+ // ProcessMessage processes a single message from userspace.
+ //
+ // If err == nil, any messages added to ms will be sent back to the
+ // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE
+ // message to be sent even if ms contains no messages.
+ ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error
+}
+
+// Provider is a function that creates a new Protocol for a specific netlink
+// protocol. It is called with the task creating the socket each time a
+// netlink socket of the registered protocol is created.
+//
+// Note that this is distinct from socket.Provider, which is used for all
+// socket families.
+type Provider func(t *kernel.Task) (Protocol, *syserr.Error)
+
+// protocols holds a map of all known address protocols and their provider,
+// keyed by netlink protocol number. It is written by RegisterProvider and
+// read by socketProvider.Socket, neither of which takes a lock.
+var protocols = make(map[int]Provider)
+
+// RegisterProvider records provider as the factory for the given netlink
+// protocol so that netlink sockets of that type can be created via
+// socket(2). It panics if the protocol already has a provider.
+//
+// Preconditions: May only be called before any netlink sockets are created.
+func RegisterProvider(protocol int, provider Provider) {
+	existing, registered := protocols[protocol]
+	if registered {
+		panic(fmt.Sprintf("Netlink protocol %d already provided by %+v", protocol, existing))
+	}
+
+	protocols[protocol] = provider
+}
+
+// socketProvider implements socket.Provider.
+//
+// It carries no state of its own.
+type socketProvider struct {
+}
+
+// Socket implements socket.Provider.Socket.
+//
+// It returns ErrSocketNotSupported for any type other than datagram or raw,
+// ErrProtocolNotSupported when no Provider is registered for protocol, and
+// otherwise wraps a new netlink Socket in a read/write file.
+func (*socketProvider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) {
+	// Netlink sockets must be specified as datagram or raw, but they
+	// behave the same regardless of type.
+	if !(stype == unix.SockDgram || stype == unix.SockRaw) {
+		return nil, syserr.ErrSocketNotSupported
+	}
+
+	newProtocol, known := protocols[protocol]
+	if !known {
+		return nil, syserr.ErrProtocolNotSupported
+	}
+
+	proto, err := newProtocol(t)
+	if err != nil {
+		return nil, err
+	}
+
+	sock, err := NewSocket(t, proto)
+	if err != nil {
+		return nil, err
+	}
+
+	dirent := socket.NewDirent(t, netlinkSocketDevice)
+	flags := fs.FileFlags{Read: true, Write: true}
+	return fs.NewFile(t, dirent, flags, sock), nil
+}
+
+// Pair implements socket.Provider.Pair by returning an error.
+//
+// All arguments are ignored; the result is always ErrNotSupported.
+func (*socketProvider) Pair(*kernel.Task, unix.SockType, int) (*fs.File, *fs.File, *syserr.Error) {
+ // Netlink sockets never support creating socket pairs.
+ return nil, nil, syserr.ErrNotSupported
+}
+
+// init registers the socket provider for the AF_NETLINK address family.
+func init() {
+ socket.RegisterProvider(linux.AF_NETLINK, &socketProvider{})
+}
diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD
new file mode 100644
index 000000000..ff3f7b7a4
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/BUILD
@@ -0,0 +1,33 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Generates route_state.go for package route from protocol.go.
+go_stateify(
+ name = "route_state",
+ srcs = ["protocol.go"],
+ out = "route_state.go",
+ package = "route",
+)
+
+# Library for the route package, including the generated route_state.go.
+go_library(
+ name = "route",
+ srcs = [
+ "protocol.go",
+ "route_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink/route",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/sentry/context",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/auth",
+ "//pkg/sentry/socket/netlink",
+ "//pkg/sentry/usermem",
+ "//pkg/state",
+ "//pkg/syserr",
+ ],
+)
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
new file mode 100644
index 000000000..d611519d4
--- /dev/null
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -0,0 +1,189 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package route provides a NETLINK_ROUTE socket protocol.
+package route
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/auth"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+)
+
+// commandKind describes the operational class of a message type.
+//
+// The route message types use the lower 2 bits of the type to describe class
+// of command.
+type commandKind int
+
+// Every constant is given the commandKind type explicitly. A type written on
+// the first entry of a constant block does not carry over to later entries
+// that have their own value, so without it they would be untyped integer
+// constants rather than commandKind values.
+const (
+	kindNew commandKind = 0x0
+	kindDel commandKind = 0x1
+	kindGet commandKind = 0x2
+	kindSet commandKind = 0x3
+)
+
+// typeKind returns the command class encoded in the low 2 bits of typ.
+func typeKind(typ uint16) commandKind {
+	return commandKind(typ & 0x3)
+}
+
+// Protocol implements netlink.Protocol.
+type Protocol struct {
+ // stack is the network stack that this provider describes.
+ //
+ // May be nil, in which case dump requests report no devices.
+ stack inet.Stack
+}
+
+// Compile-time check that *Protocol satisfies netlink.Protocol.
+var _ netlink.Protocol = (*Protocol)(nil)
+
+// NewProtocol creates a NETLINK_ROUTE netlink.Protocol describing the network
+// stack returned by t.NetworkContext(). The returned error is always nil.
+func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) {
+	p := new(Protocol)
+	p.stack = t.NetworkContext()
+	return p, nil
+}
+
+// Protocol implements netlink.Protocol.Protocol.
+//
+// It always returns linux.NETLINK_ROUTE.
+func (p *Protocol) Protocol() int {
+ return linux.NETLINK_ROUTE
+}
+
+// dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests.
+//
+// It adds one RTM_NEWLINK message to ms per interface known to the stack and
+// always marks ms as multi-part. The request contents (hdr, data) are not
+// inspected, and the returned error is always nil.
+func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+	// NLM_F_DUMP + RTM_GETLINK messages are supposed to include an
+	// ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some
+	// userspace applications (including glibc) still include rtgenmsg.
+	// Linux has a workaround based on the total message length.
+	//
+	// We don't bother to check for either, since we don't support any
+	// extra attributes that may be included anyways.
+	//
+	// The message may also contain netlink attribute IFLA_EXT_MASK, which
+	// we don't support.
+
+	// We always send back an NLMSG_DONE.
+	ms.Multi = true
+
+	if p.stack == nil {
+		// No network devices.
+		return nil
+	}
+
+	// The RTM_GETLINK dump response is a set of messages each containing
+	// an InterfaceInfoMessage followed by a set of netlink attributes.
+	for idx, iface := range p.stack.Interfaces() {
+		info := linux.InterfaceInfoMessage{
+			Family: linux.AF_UNSPEC,
+			Type:   iface.DeviceType,
+			Index:  idx,
+			Flags:  iface.Flags,
+		}
+
+		msg := ms.AddMessage(linux.NetlinkMessageHeader{Type: linux.RTM_NEWLINK})
+		msg.Put(info)
+		msg.PutAttrString(linux.IFLA_IFNAME, iface.Name)
+
+		// TODO: There are many more attributes, such as
+		// MAC address.
+	}
+
+	return nil
+}
+
+// dumpAddrs handles RTM_GETADDR + NLM_F_DUMP requests.
+//
+// It adds one RTM_NEWADDR message to ms per address of every interface known
+// to the stack and always marks ms as multi-part. The request contents (hdr,
+// data) are not inspected, and the returned error is always nil.
+func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+	// RTM_GETADDR dump requests need not contain anything more than the
+	// netlink header and 1 byte protocol family common to all
+	// NETLINK_ROUTE requests.
+	//
+	// TODO: Filter output by passed protocol family.
+
+	// We always send back an NLMSG_DONE.
+	ms.Multi = true
+
+	if p.stack == nil {
+		// No network devices.
+		return nil
+	}
+
+	// The RTM_GETADDR dump response is a set of RTM_NEWADDR messages each
+	// containing an InterfaceAddrMessage followed by a set of netlink
+	// attributes.
+	for idx, addrs := range p.stack.InterfaceAddrs() {
+		for _, addr := range addrs {
+			info := linux.InterfaceAddrMessage{
+				Family:    addr.Family,
+				PrefixLen: addr.PrefixLen,
+				Index:     uint32(idx),
+			}
+
+			msg := ms.AddMessage(linux.NetlinkMessageHeader{Type: linux.RTM_NEWADDR})
+			msg.Put(info)
+			msg.PutAttr(linux.IFA_ADDRESS, []byte(addr.Addr))
+
+			// TODO: There are many more attributes.
+		}
+	}
+
+	return nil
+}
+
+// ProcessMessage implements netlink.Protocol.ProcessMessage.
+//
+// Messages with no payload are silently ignored. Non-GET message types
+// require CAP_NET_ADMIN and fail with ErrPermissionDenied without it. Only
+// RTM_GETLINK and RTM_GETADDR dump requests are handled; everything else
+// fails with ErrNotSupported.
+func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error {
+	// All messages start with a 1 byte protocol family.
+	if len(data) < 1 {
+		// Linux ignores messages missing the protocol family. See
+		// net/core/rtnetlink.c:rtnetlink_rcv_msg.
+		return nil
+	}
+
+	// Non-GET message types require CAP_NET_ADMIN. Credentials are only
+	// looked up when the check is needed.
+	if typeKind(hdr.Type) != kindGet && !auth.CredentialsFromContext(ctx).HasCapability(linux.CAP_NET_ADMIN) {
+		return syserr.ErrPermissionDenied
+	}
+
+	// TODO: Only the dump variant of the types below are
+	// supported.
+	isDump := hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP
+	switch {
+	case !isDump:
+		return syserr.ErrNotSupported
+	case hdr.Type == linux.RTM_GETLINK:
+		return p.dumpLinks(ctx, hdr, data, ms)
+	case hdr.Type == linux.RTM_GETADDR:
+		return p.dumpAddrs(ctx, hdr, data, ms)
+	default:
+		return syserr.ErrNotSupported
+	}
+}
+
+// init registers the NETLINK_ROUTE provider, making NewProtocol the factory
+// used for NETLINK_ROUTE sockets.
+func init() {
+ netlink.RegisterProvider(linux.NETLINK_ROUTE, NewProtocol)
+}
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
new file mode 100644
index 000000000..2d0e59ceb
--- /dev/null
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -0,0 +1,517 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package netlink provides core functionality for netlink sockets.
+package netlink
+
+import (
+ "sync"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/netlink/port"
+ sunix "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// defaultSendBufferSize is the default size for the send buffer, in bytes
+// (16 KiB). NewSocket uses it as the initial Socket.sendBufferSize.
+const defaultSendBufferSize = 16 * 1024
+
+// netlinkSocketDevice is the netlink socket virtual device. It is the device
+// passed to socket.NewDirent when a netlink socket file is created.
+var netlinkSocketDevice = device.NewAnonDevice()
+
+// Socket is the base socket type for netlink sockets.
+//
+// This implementation only supports userspace sending and receiving messages
+// to/from the kernel.
+//
+// Socket implements socket.Socket.
+type Socket struct {
+ socket.ReceiveTimeout
+ fsutil.PipeSeek `state:"nosave"`
+ fsutil.NotDirReaddir `state:"nosave"`
+ fsutil.NoFsync `state:"nosave"`
+ fsutil.NoopFlush `state:"nosave"`
+ fsutil.NoMMap `state:"nosave"`
+
+ // ports provides netlink port allocation.
+ ports *port.Manager
+
+ // protocol is the netlink protocol implementation.
+ protocol Protocol
+
+ // ep is a datagram unix endpoint used to buffer messages sent from the
+ // kernel to userspace. RecvMsg reads messages from this endpoint.
+ ep unix.Endpoint
+
+ // connection is the kernel's connection to ep, used to write messages
+ // sent to userspace.
+ connection unix.ConnectedEndpoint
+
+ // mu protects the fields below.
+ mu sync.Mutex `state:"nosave"`
+
+ // bound indicates that portID is valid, i.e. that a port has been
+ // allocated from ports and must be released when the socket is.
+ bound bool
+
+ // portID is the port ID allocated for this socket.
+ portID int32
+
+ // sendBufferSize is the send buffer "size". We don't actually have a
+ // fixed buffer but only consume this many bytes.
+ sendBufferSize uint64
+}
+
+// Compile-time check that *Socket satisfies socket.Socket.
+var _ socket.Socket = (*Socket)(nil)
+
+// NewSocket creates a new, unbound Socket speaking the given protocol.
+//
+// The socket is backed by a connectionless unix endpoint that buffers
+// kernel -> user messages, plus a one-way connection the kernel writes to.
+// If either cannot be set up, the endpoint is closed and the translated
+// error is returned.
+func NewSocket(t *kernel.Task, protocol Protocol) (*Socket, *syserr.Error) {
+	// Datagram endpoint used to buffer kernel -> user messages.
+	endpoint := unix.NewConnectionless()
+
+	// Bind the endpoint for good measure so we can connect to it. The
+	// bound address will never be exposed.
+	if terr := endpoint.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); terr != nil {
+		endpoint.Close()
+		return nil, syserr.TranslateNetstackError(terr)
+	}
+
+	// Create a connection from which the kernel can write messages.
+	bound := endpoint.(unix.BoundEndpoint)
+	kernelConn, terr := bound.UnidirectionalConnect()
+	if terr != nil {
+		endpoint.Close()
+		return nil, syserr.TranslateNetstackError(terr)
+	}
+
+	s := &Socket{
+		ports:          t.Kernel().NetlinkPorts(),
+		protocol:       protocol,
+		ep:             endpoint,
+		connection:     kernelConn,
+		sendBufferSize: defaultSendBufferSize,
+	}
+	return s, nil
+}
+
+// Release implements fs.FileOperations.Release.
+//
+// It drops the kernel's connection, closes the buffering endpoint and, if the
+// socket was bound, returns its port ID to the port manager.
+func (s *Socket) Release() {
+	s.connection.Release()
+	s.ep.Close()
+
+	if !s.bound {
+		// No port was ever allocated.
+		return
+	}
+	s.ports.Release(s.protocol.Protocol(), s.portID)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// Readability comes from the buffering endpoint; writability is always
+// reported when asked for.
+func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask {
+	// ep holds messages to be read and thus handles EventIn readiness.
+	ready := s.ep.Readiness(mask)
+
+	if mask&waiter.EventOut != waiter.EventOut {
+		return ready
+	}
+
+	// sendMsg handles messages synchronously and is thus always ready for
+	// writing.
+	return ready | waiter.EventOut
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+//
+// The registration is forwarded to the buffering endpoint, which is the
+// source of readable events.
+func (s *Socket) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+ // Writable readiness never changes, so no registration is needed.
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+//
+// It removes e from the buffering endpoint, mirroring EventRegister.
+func (s *Socket) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+//
+// No ioctls are implemented; every request fails with ENOTTY.
+func (s *Socket) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ // TODO: no ioctls supported.
+ return 0, syserror.ENOTTY
+}
+
+// ExtractSockAddr extracts the SockAddrNetlink from b.
+func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) {
+ if len(b) < linux.SockAddrNetlinkSize {
+ return nil, syserr.ErrBadAddress
+ }
+
+ var sa linux.SockAddrNetlink
+ binary.Unmarshal(b[:linux.SockAddrNetlinkSize], usermem.ByteOrder, &sa)
+
+ if sa.Family != linux.AF_NETLINK {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ return &sa, nil
+}
+
+// bindPort binds this socket to a port, preferring 'port' if it is available.
+//
+// port of 0 defaults to the ThreadGroup ID.
+//
+// An already-bound socket accepts only its current port ID and otherwise
+// fails with ErrInvalidArgument. ErrBusy is returned when the port manager
+// cannot allocate a port.
+//
+// Preconditions: mu is held.
+func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error {
+	if s.bound {
+		// Re-binding is only allowed if the port doesn't change.
+		if port == s.portID {
+			return nil
+		}
+		return syserr.ErrInvalidArgument
+	}
+
+	hint := port
+	if hint == 0 {
+		hint = int32(t.ThreadGroup().ID())
+	}
+
+	// The manager may hand back a different port if the hint is taken.
+	allocated, ok := s.ports.Allocate(s.protocol.Protocol(), hint)
+	if !ok {
+		return syserr.ErrBusy
+	}
+
+	s.portID = allocated
+	s.bound = true
+	return nil
+}
+
+// Bind implements socket.Socket.Bind.
+//
+// sockaddr must decode as a SockAddrNetlink with no multicast groups set;
+// the socket is then bound to the requested port ID via bindPort.
+func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+	addr, err := ExtractSockAddr(sockaddr)
+	switch {
+	case err != nil:
+		return err
+	case addr.Groups != 0:
+		// No support for multicast groups yet.
+		return syserr.ErrPermissionDenied
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	return s.bindPort(t, int32(addr.PortID))
+}
+
+// Connect implements socket.Socket.Connect.
+//
+// Only the kernel (port ID 0) is accepted as a destination. Connecting to it
+// binds the socket to an auto-selected port if it is not bound yet. Any other
+// destination, or any multicast group, fails with ErrPermissionDenied.
+func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+	addr, err := ExtractSockAddr(sockaddr)
+	if err != nil {
+		return err
+	}
+
+	// No support for multicast groups yet.
+	if addr.Groups != 0 {
+		return syserr.ErrPermissionDenied
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	if addr.PortID != 0 {
+		// We don't support non-kernel destination ports. Linux returns
+		// EPERM if applications attempt to do this without
+		// NL_CFG_F_NONROOT_SEND, so we emulate that.
+		return syserr.ErrPermissionDenied
+	}
+
+	// Netlink sockets default to connected to the kernel, but connecting
+	// anyways automatically binds if not already bound.
+	if s.bound {
+		return nil
+	}
+
+	// Pass port 0 to get an auto-selected port ID.
+	return s.bindPort(t, 0)
+}
+
+// Accept implements socket.Socket.Accept.
+//
+// All arguments are ignored; the call always fails with ErrNotSupported.
+func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+ // Netlink sockets never support accept.
+ return 0, nil, 0, syserr.ErrNotSupported
+}
+
+// Listen implements socket.Socket.Listen.
+//
+// All arguments are ignored; the call always fails with ErrNotSupported.
+func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ // Netlink sockets never support listen.
+ return syserr.ErrNotSupported
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+//
+// All arguments are ignored; the call always fails with ErrNotSupported.
+func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error {
+ // Netlink sockets never support shutdown.
+ return syserr.ErrNotSupported
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+//
+// No options are implemented; every request fails with
+// ErrProtocolNotAvailable regardless of level and name.
+func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+ // TODO: no sockopts supported.
+ return nil, syserr.ErrProtocolNotAvailable
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+//
+// No options are implemented; every request fails with
+// ErrProtocolNotAvailable regardless of level and name.
+func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+ // TODO: no sockopts supported.
+ return syserr.ErrProtocolNotAvailable
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+func (s *Socket) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+
+ sa := linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ PortID: uint32(s.portID),
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+func (s *Socket) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+ sa := linux.SockAddrNetlink{
+ Family: linux.AF_NETLINK,
+ // TODO: Support non-kernel peers. For now the peer
+ // must be the kernel.
+ PortID: 0,
+ }
+ return sa, uint32(binary.Size(sa)), nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+//
+// The sender address is always reported as the kernel (port ID 0). One
+// receive is attempted immediately; if it would block and MSG_DONTWAIT is not
+// set, the task waits for the socket to become readable (honouring the
+// optional deadline) and retries. With MSG_TRUNC the returned length is the
+// size recorded by the reader in MsgSize rather than the number of bytes
+// copied. No control messages are ever returned.
+func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, unix.ControlMessages, *syserr.Error) {
+	sender := linux.SockAddrNetlink{Family: linux.AF_NETLINK, PortID: 0}
+	senderLen := uint32(binary.Size(sender))
+
+	reader := sunix.EndpointReader{
+		Endpoint: s.ep,
+		Peek:     flags&linux.MSG_PEEK != 0,
+	}
+
+	// finish turns the outcome of a completed receive attempt into the
+	// RecvMsg result tuple.
+	finish := func(copied int64, err error) (int, interface{}, uint32, unix.ControlMessages, *syserr.Error) {
+		if flags&linux.MSG_TRUNC != 0 {
+			copied = int64(reader.MsgSize)
+		}
+		return int(copied), sender, senderLen, unix.ControlMessages{}, syserr.FromError(err)
+	}
+
+	// Fast path: data is already available, the receive failed outright, or
+	// the caller does not want to wait.
+	copied, err := dst.CopyOutFrom(t, &reader)
+	if err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+		return finish(copied, err)
+	}
+
+	// Slow path: register for readability notifications and keep retrying
+	// until something other than "would block" comes back.
+	entry, wake := waiter.NewChannelEntry(nil)
+	s.EventRegister(&entry, waiter.EventIn)
+	defer s.EventUnregister(&entry)
+
+	for {
+		copied, err := dst.CopyOutFrom(t, &reader)
+		if err != syserror.ErrWouldBlock {
+			return finish(copied, err)
+		}
+
+		switch blockErr := t.BlockWithDeadline(wake, haveDeadline, deadline); blockErr {
+		case nil:
+			// Woken up; try the receive again.
+		case syserror.ETIMEDOUT:
+			// An expired deadline is surfaced as EAGAIN.
+			return 0, nil, 0, unix.ControlMessages{}, syserr.ErrTryAgain
+		default:
+			return 0, nil, 0, unix.ControlMessages{}, syserr.FromError(blockErr)
+		}
+	}
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// A zero-length destination returns immediately; otherwise a single receive
+// is performed from the socket's endpoint into dst.
+func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+	if dst.NumBytes() != 0 {
+		reader := sunix.EndpointReader{Endpoint: s.ep}
+		return dst.CopyOutFrom(ctx, &reader)
+	}
+	return 0, nil
+}
+
+// sendResponse sends the response messages in ms back to userspace.
+//
+// All messages in ms are finalized and sent as one datagram. If ms is a
+// multi-part set, a trailing NLMSG_DONE message follows in a datagram of its
+// own, even when ms carried no messages at all. A full receive buffer
+// (tcpip.ErrWouldBlock) is not treated as an error: the datagram is dropped.
+func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error {
+	// deliver queues one datagram made of parts on the connection and wakes
+	// up readers when the connection asks for it.
+	deliver := func(parts [][]byte) *syserr.Error {
+		// RecvMsg never receives the address, so we don't need to send
+		// one.
+		_, notify, terr := s.connection.Send(parts, unix.ControlMessages{}, tcpip.FullAddress{})
+		// If the buffer is full, we simply drop messages, just like
+		// Linux.
+		if terr != nil && terr != tcpip.ErrWouldBlock {
+			return syserr.TranslateNetstackError(terr)
+		}
+		if notify {
+			s.connection.SendNotify()
+		}
+		return nil
+	}
+
+	// Linux combines multiple netlink messages into a single datagram.
+	parts := make([][]byte, 0, len(ms.Messages))
+	for _, msg := range ms.Messages {
+		parts = append(parts, msg.Finalize())
+	}
+	if len(parts) > 0 {
+		if err := deliver(parts); err != nil {
+			return err
+		}
+	}
+
+	if !ms.Multi {
+		return nil
+	}
+
+	// N.B. multi-part messages should still send NLMSG_DONE even if
+	// MessageSet contains no messages.
+	//
+	// N.B. NLMSG_DONE is always sent in a different datagram. See
+	// net/netlink/af_netlink.c:netlink_dump.
+	done := NewMessage(linux.NetlinkMessageHeader{
+		Type:   linux.NLMSG_DONE,
+		Flags:  linux.NLM_F_MULTI,
+		Seq:    ms.Seq,
+		PortID: uint32(ms.PortID),
+	})
+	return deliver([][]byte{done.Finalize()})
+}
+
+// processMessages handles each message in buf, passing it to the protocol
+// handler for final handling.
+//
+// buf may contain several netlink messages laid out back to back, each
+// padded to linux.NLMSG_ALIGNTO. Parsing stops silently at the first header
+// that is truncated or whose length field is inconsistent with buf. Control
+// messages (types below linux.NLMSG_MIN_TYPE) are skipped. The first error
+// from the protocol handler or from sending a response aborts processing and
+// is returned; a request carrying NLM_F_ACK fails with
+// syserr.ErrNotSupported.
+func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error {
+	for len(buf) > 0 {
+		if len(buf) < linux.NetlinkMessageHeaderSize {
+			// Linux ignores messages that are too short. See
+			// net/netlink/af_netlink.c:netlink_rcv_skb.
+			break
+		}
+
+		var hdr linux.NetlinkMessageHeader
+		binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr)
+
+		if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) {
+			// Linux ignores malformed messages. See
+			// net/netlink/af_netlink.c:netlink_rcv_skb.
+			break
+		}
+
+		// Data from this message.
+		data := buf[linux.NetlinkMessageHeaderSize:hdr.Length]
+
+		// Advance to the next message. The final message need not be
+		// padded, so the aligned offset may overshoot the buffer; clamp it
+		// to the end so that nothing is left over for the next iteration.
+		next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO)
+		if next > len(buf) {
+			next = len(buf)
+		}
+		buf = buf[next:]
+
+		// Ignore control messages.
+		if hdr.Type < linux.NLMSG_MIN_TYPE {
+			continue
+		}
+
+		// TODO: ACKs not supported yet.
+		if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK {
+			return syserr.ErrNotSupported
+		}
+
+		ms := NewMessageSet(s.portID, hdr.Seq)
+		if err := s.protocol.ProcessMessage(ctx, hdr, data, ms); err != nil {
+			return err
+		}
+
+		if err := s.sendResponse(ctx, ms); err != nil {
+			return err
+		}
+	}
+
+	return nil
+}
+
+// sendMsg is the core of message send, used for SendMsg and Write.
+//
+// Only the kernel can be addressed: a destination naming multicast groups or
+// a non-zero port ID is rejected with syserr.ErrPermissionDenied. The whole
+// message is copied in before any of it is processed, and a message larger
+// than the send buffer fails with syserr.ErrMessageTooLong. On any failure
+// the returned byte count is 0.
+func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
+	if len(to) != 0 {
+		addr, err := ExtractSockAddr(to)
+		switch {
+		case err != nil:
+			return 0, err
+		case addr.Groups != 0:
+			// No support for multicast groups yet.
+			return 0, syserr.ErrPermissionDenied
+		case addr.PortID != 0:
+			// Non-kernel destinations not supported yet. Treat as if
+			// NL_CFG_F_NONROOT_SEND is not set.
+			return 0, syserr.ErrPermissionDenied
+		}
+	}
+
+	s.mu.Lock()
+	defer s.mu.Unlock()
+
+	// For simplicity, and consistency with Linux, we copy in the entire
+	// message up front.
+	size := src.NumBytes()
+	if uint64(size) > s.sendBufferSize {
+		return 0, syserr.ErrMessageTooLong
+	}
+
+	payload := make([]byte, size)
+	copied, copyErr := src.CopyIn(ctx, payload)
+	if copyErr != nil {
+		// Don't partially consume messages.
+		return 0, syserr.FromError(copyErr)
+	}
+
+	if procErr := s.processMessages(ctx, payload); procErr != nil {
+		return 0, procErr
+	}
+	return copied, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+//
+// It forwards all arguments unchanged to sendMsg, using the task as the
+// context.
+func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
+	sent, err := s.sendMsg(t, src, to, flags, controlMessages)
+	return sent, err
+}
+
+// Write implements fs.FileOperations.Write.
+func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+ n, err := s.sendMsg(ctx, src, nil, 0, unix.ControlMessages{})
+ return int64(n), err.ToError()
+}
diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD
new file mode 100644
index 000000000..b0351b363
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/BUILD
@@ -0,0 +1,59 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library")
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# Socket implementation that issues an RPC for each socket syscall.
+go_library(
+ name = "rpcinet",
+ srcs = [
+ "device.go",
+ "rpcinet.go",
+ "socket.go",
+ "stack.go",
+ "stack_unsafe.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ ":syscall_rpc_go_proto",
+ "//pkg/abi/linux",
+ "//pkg/binary",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/inet",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/kdefs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/hostinet",
+ "//pkg/sentry/socket/rpcinet/conn",
+ "//pkg/sentry/socket/rpcinet/notifier",
+ "//pkg/sentry/usermem",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip/buffer",
+ "//pkg/tcpip/transport/unix",
+ "//pkg/unet",
+ "//pkg/waiter",
+ ],
+)
+
+# Message schema for the syscall RPCs, defined in syscall_rpc.proto.
+proto_library(
+ name = "syscall_rpc_proto",
+ srcs = ["syscall_rpc.proto"],
+ visibility = [
+ "//visibility:public",
+ ],
+)
+
+# Go bindings for :syscall_rpc_proto. The conn and notifier packages and the
+# rpcinet sources import this package under the alias "pb".
+go_proto_library(
+ name = "syscall_rpc_go_proto",
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto",
+ proto = ":syscall_rpc_proto",
+ visibility = [
+ "//visibility:public",
+ ],
+)
diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD
new file mode 100644
index 000000000..4923dee4b
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/conn/BUILD
@@ -0,0 +1,17 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# RPC connection to a syscall RPC server.
+go_library(
+ name = "conn",
+ srcs = ["conn.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/binary",
+ "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
+ "//pkg/syserr",
+ "//pkg/unet",
+ "@com_github_golang_protobuf//proto:go_default_library",
+ ],
+)
diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go
new file mode 100644
index 000000000..ea6ec87ed
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/conn/conn.go
@@ -0,0 +1,167 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package conn is an RPC connection to a syscall RPC server.
+package conn
+
+import (
+ "fmt"
+ "sync"
+ "sync/atomic"
+ "syscall"
+
+ "github.com/golang/protobuf/proto"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+)
+
+// request is the bookkeeping entry for one in-flight RPC, stored in
+// RPCConnection.requests under its request ID.
+type request struct {
+ // response holds the raw, still-marshalled reply bytes once the reader
+ // goroutine has received them.
+ response []byte
+ // ready is closed by the reader goroutine when the reply has arrived.
+ ready chan struct{}
+ // ignoreResult, when set, makes the reader goroutine delete the entry on
+ // arrival instead of storing the reply.
+ ignoreResult bool
+}
+
+// RPCConnection represents a single RPC connection to a syscall gofer.
+//
+// Requests are written to the socket by NewRequest and replies are read by a
+// goroutine started in NewRPCConnection, which matches them to requests by
+// ID.
+type RPCConnection struct {
+ // reqID is the ID of the last request and must be accessed atomically.
+ reqID uint64
+
+ // sendMu is held while a request is written to socket.
+ sendMu sync.Mutex
+ socket *unet.Socket
+
+ // reqMu is held around every access to requests.
+ reqMu sync.Mutex
+ // requests maps request IDs to their in-flight state.
+ requests map[uint64]request
+}
+
+// NewRPCConnection initializes a RPC connection to a socket gofer.
+//
+// It also starts the goroutine that reads replies from s. Each reply is
+// framed by a 16-byte little-endian header (payload length, then request ID)
+// followed by the payload. The payload is stored on the matching request, or
+// the request is dropped if it was issued with ignoreResult, and the
+// request's ready channel is then closed. Any read error panics.
+func NewRPCConnection(s *unet.Socket) *RPCConnection {
+	conn := &RPCConnection{socket: s, requests: map[uint64]request{}}
+
+	// fill reads from the socket until p is full, panicking with the given
+	// message prefix on failure.
+	fill := func(p []byte, failure string) {
+		for off := 0; off < len(p); {
+			got, err := conn.socket.Read(p[off:])
+			if err != nil {
+				panic(fmt.Sprint(failure, err))
+			}
+			off += got
+		}
+	}
+
+	go func() { // S/R-FIXME
+		var header [16]byte
+		for {
+			fill(header[:], "error reading length from socket rpc gofer: ")
+
+			payload := make([]byte, binary.LittleEndian.Uint64(header[:8]))
+			id := binary.LittleEndian.Uint64(header[8:])
+
+			fill(payload, "error reading request from socket rpc gofer: ")
+
+			conn.reqMu.Lock()
+			pending := conn.requests[id]
+			if pending.ignoreResult {
+				delete(conn.requests, id)
+			} else {
+				pending.response = payload
+				conn.requests[id] = pending
+			}
+			conn.reqMu.Unlock()
+
+			// Signal the waiter only after the map has been updated.
+			close(pending.ready)
+		}
+	}()
+
+	return conn
+}
+
+// NewRequest makes a request to the RPC gofer and returns the request ID and a
+// channel which will be closed once the request completes.
+//
+// The request is registered before it is written, and the write (a 16-byte
+// little-endian header of payload length and request ID, followed by the
+// marshalled request) happens under sendMu. Marshalling or write failures
+// panic.
+func (c *RPCConnection) NewRequest(req pb.SyscallRequest, ignoreResult bool) (uint64, chan struct{}) {
+	payload, err := proto.Marshal(&req)
+	if err != nil {
+		panic(fmt.Sprint("invalid proto: ", err))
+	}
+
+	id := atomic.AddUint64(&c.reqID, 1)
+	ready := make(chan struct{})
+
+	c.reqMu.Lock()
+	c.requests[id] = request{ready: ready, ignoreResult: ignoreResult}
+	c.reqMu.Unlock()
+
+	// flush writes all of p to the socket, panicking with the given message
+	// prefix on failure.
+	flush := func(p []byte, failure string) {
+		for off := 0; off < len(p); {
+			wrote, err := c.socket.Write(p[off:])
+			if err != nil {
+				panic(fmt.Sprint(failure, err))
+			}
+			off += wrote
+		}
+	}
+
+	var header [16]byte
+	binary.LittleEndian.PutUint64(header[:8], uint64(len(payload)))
+	binary.LittleEndian.PutUint64(header[8:], id)
+
+	c.sendMu.Lock()
+	defer c.sendMu.Unlock()
+
+	flush(header[:], "error writing length and ID to socket gofer: ")
+	flush(payload, "error writing request to socket gofer: ")
+
+	return id, ready
+}
+
+// RPCReadFile will execute the ReadFile helper RPC method which avoids the
+// common pattern of open(2), read(2), close(2) by doing all three operations
+// as a single RPC. It will read the entire file or return EFBIG if the file
+// was too large.
+func (c *RPCConnection) RPCReadFile(path string) ([]byte, *syserr.Error) {
+ req := &pb.SyscallRequest_ReadFile{&pb.ReadFileRequest{
+ Path: path,
+ }}
+
+ id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+ <-ch
+
+ res := c.Request(id).Result.(*pb.SyscallResponse_ReadFile).ReadFile.Result
+ if e, ok := res.(*pb.ReadFileResponse_ErrorNumber); ok {
+ return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber))
+ }
+
+ return res.(*pb.ReadFileResponse_Data).Data, nil
+}
+
+// Request retrieves the request corresponding to the given request ID.
+//
+// The channel returned by NewRequest must have been closed before Request can
+// be called. This will happen automatically, do not manually close the
+// channel.
+//
+// The request's entry is removed from the connection, so each ID can be
+// collected once. A reply that cannot be unmarshalled panics.
+func (c *RPCConnection) Request(id uint64) pb.SyscallResponse {
+	// Take the entry out of the table; decoding happens outside the lock.
+	c.reqMu.Lock()
+	raw := c.requests[id].response
+	delete(c.requests, id)
+	c.reqMu.Unlock()
+
+	var decoded pb.SyscallResponse
+	err := proto.Unmarshal(raw, &decoded)
+	if err != nil {
+		panic(fmt.Sprint("invalid proto: ", err))
+	}
+	return decoded
+}
diff --git a/pkg/sentry/socket/rpcinet/device.go b/pkg/sentry/socket/rpcinet/device.go
new file mode 100644
index 000000000..f7b63436e
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/device.go
@@ -0,0 +1,19 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// socketDevice is the anonymous device passed to socket.NewDirent for every
+// dirent backing an rpcinet socket.
+var socketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD
new file mode 100644
index 000000000..6f3b06a05
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/notifier/BUILD
@@ -0,0 +1,15 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+
+# FD notifier implemented over the syscall RPC connection.
+go_library(
+ name = "notifier",
+ srcs = ["notifier.go"],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto",
+ "//pkg/sentry/socket/rpcinet/conn",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go
new file mode 100644
index 000000000..f88a908ed
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/notifier/notifier.go
@@ -0,0 +1,230 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package notifier implements an FD notifier implementation over RPC.
+package notifier
+
+import (
+ "fmt"
+ "sync"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// fdInfo is the per-FD state kept in Notifier.fdMap.
+type fdInfo struct {
+ // queue receives the events reported for the FD.
+ queue *waiter.Queue
+ // waiting is true while the FD is registered with the remote epoll
+ // object; it is maintained by waitFD.
+ waiting bool
+}
+
+// Notifier holds all the state necessary to issue notifications when IO events
+// occur in the observed FDs.
+//
+// The epoll object itself lives on the other side of the RPC connection;
+// every epoll operation is issued as an RPC.
+type Notifier struct {
+ // rpcConn is the connection that is used for sending RPCs.
+ rpcConn *conn.RPCConnection
+
+ // epFD is the epoll file descriptor used to register for io
+ // notifications.
+ epFD uint32
+
+ // mu protects fdMap.
+ mu sync.Mutex
+
+ // fdMap maps file descriptors to their notification queues and waiting
+ // status.
+ fdMap map[uint32]*fdInfo
+}
+
+// NewRPCNotifier creates a new notifier object.
+//
+// It creates an epoll object over cn via an EpollCreate1 RPC and starts the
+// goroutine that waits on it. If the RPC reports an error number, that errno
+// is returned and no notifier is created.
+func NewRPCNotifier(cn *conn.RPCConnection) (*Notifier, error) {
+	create := pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCreate1{&pb.EpollCreate1Request{}}}
+	id, done := cn.NewRequest(create, false /* ignoreResult */)
+	<-done
+
+	result := cn.Request(id).Result.(*pb.SyscallResponse_EpollCreate1).EpollCreate1.Result
+	if failure, failed := result.(*pb.EpollCreate1Response_ErrorNumber); failed {
+		return nil, syscall.Errno(failure.ErrorNumber)
+	}
+
+	n := &Notifier{
+		rpcConn: cn,
+		epFD:    result.(*pb.EpollCreate1Response_Fd).Fd,
+		fdMap:   make(map[uint32]*fdInfo),
+	}
+	go n.waitAndNotify() // S/R-FIXME
+	return n, nil
+}
+
+// waitFD waits on mask for fd. The fdMap mutex must be held.
+//
+// The remote epoll registration is brought in line with mask: the FD is
+// added when it was not registered and mask is non-empty, modified when it
+// was registered and mask is non-empty, and deleted when it was registered
+// and mask is empty. The reply to a delete is collected but its error number
+// is not inspected. Events are always requested edge-triggered.
+func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error {
+	if mask == 0 {
+		if !fi.waiting {
+			// Not registered and nothing to wait for.
+			return nil
+		}
+
+		id, done := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_DEL, Fd: fd}}}, false /* ignoreResult */)
+		<-done
+		n.rpcConn.Request(id)
+
+		fi.waiting = false
+		return nil
+	}
+
+	ev := pb.EpollEvent{
+		Events: uint32(mask) | -syscall.EPOLLET,
+		Fd:     fd,
+	}
+
+	if fi.waiting {
+		id, done := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_MOD, Fd: fd, Event: &ev}}}, false /* ignoreResult */)
+		<-done
+
+		if errno := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber; errno != 0 {
+			return syscall.Errno(errno)
+		}
+		return nil
+	}
+
+	id, done := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_ADD, Fd: fd, Event: &ev}}}, false /* ignoreResult */)
+	<-done
+
+	if errno := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber; errno != 0 {
+		return syscall.Errno(errno)
+	}
+
+	// Only mark the FD as registered once the add has succeeded.
+	fi.waiting = true
+	return nil
+}
+
+// addFD adds an FD to the list of FDs observed by n.
+//
+// The FD is only recorded in the map; no epoll registration is made here.
+// Adding an FD that is already present panics.
+func (n *Notifier) addFD(fd uint32, queue *waiter.Queue) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	_, present := n.fdMap[fd]
+	if present {
+		// Panic if we're already notifying on this FD.
+		panic(fmt.Sprintf("File descriptor %d added twice", fd))
+	}
+
+	// We have nothing to wait for at the moment. Just add it to the map.
+	info := &fdInfo{queue: queue}
+	n.fdMap[fd] = info
+}
+
+// updateFD updates the set of events the FD needs to be notified on.
+//
+// The event set is taken from the FD's waiter queue. An FD that is not being
+// observed is ignored and nil is returned.
+func (n *Notifier) updateFD(fd uint32) error {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	info, observed := n.fdMap[fd]
+	if !observed {
+		return nil
+	}
+	return n.waitFD(fd, info, info.queue.Events())
+}
+
+// removeFD removes an FD from the list of FDs observed by n.
+//
+// The FD is first dropped from the remote epoll object (if it was registered
+// there) and then deleted from the map. Removing an FD that is not being
+// observed is a no-op; without this check the missing map entry would be
+// passed to waitFD as a nil *fdInfo and dereferenced.
+func (n *Notifier) removeFD(fd uint32) {
+	n.mu.Lock()
+	defer n.mu.Unlock()
+
+	fi, ok := n.fdMap[fd]
+	if !ok {
+		return
+	}
+
+	// Remove from the epoll object, then from the map. The result is
+	// deliberately ignored: with an empty mask waitFD only issues a delete
+	// and returns nil.
+	_ = n.waitFD(fd, fi, 0)
+	delete(n.fdMap, fd)
+}
+
+// hasFD returns true if the FD is in the list of observed FDs.
+func (n *Notifier) hasFD(fd uint32) bool {
+	n.mu.Lock()
+	_, observed := n.fdMap[fd]
+	n.mu.Unlock()
+
+	return observed
+}
+
+// waitAndNotify loops waiting for io event notifications from the epoll
+// object. Once notifications arrive, they are dispatched to the
+// registered queue.
+//
+// Each iteration issues a blocking EpollWait RPC for up to 100 events.
+// EINTR and EAGAIN are retried; any other error number ends the loop and is
+// returned. Events for FDs that are no longer in fdMap are discarded.
+func (n *Notifier) waitAndNotify() error {
+	for {
+		wait := pb.SyscallRequest{Args: &pb.SyscallRequest_EpollWait{&pb.EpollWaitRequest{Fd: n.epFD, NumEvents: 100, Msec: -1}}}
+		id, done := n.rpcConn.NewRequest(wait, false /* ignoreResult */)
+		<-done
+
+		result := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollWait).EpollWait.Result
+		if failure, failed := result.(*pb.EpollWaitResponse_ErrorNumber); failed {
+			errno := syscall.Errno(failure.ErrorNumber)
+			// NOTE: I don't think epoll_wait can return EAGAIN but I'm being
+			// conservatively careful here since exiting the notification thread
+			// would be really bad.
+			if errno != syscall.EINTR && errno != syscall.EAGAIN {
+				return errno
+			}
+			continue
+		}
+
+		events := result.(*pb.EpollWaitResponse_Events).Events.Events
+
+		n.mu.Lock()
+		for _, ev := range events {
+			info, observed := n.fdMap[ev.Fd]
+			if !observed {
+				continue
+			}
+			info.queue.Notify(waiter.EventMask(ev.Events))
+		}
+		n.mu.Unlock()
+	}
+}
+
+// AddFD adds an FD to the list of observed FDs.
+func (n *Notifier) AddFD(fd uint32, queue *waiter.Queue) error {
+ n.addFD(fd, queue)
+ return nil
+}
+
+// UpdateFD updates the set of events the FD needs to be notified on.
+//
+// It is the exported entry point for updateFD.
+func (n *Notifier) UpdateFD(fd uint32) error {
+	err := n.updateFD(fd)
+	return err
+}
+
+// RemoveFD removes an FD from the list of observed FDs.
+//
+// It is the exported entry point for removeFD.
+func (n *Notifier) RemoveFD(fd uint32) {
+	n.removeFD(fd)
+	return
+}
+
+// HasFD returns true if the FD is in the list of observed FDs.
+//
+// This should only be used by tests to assert that FDs are correctly
+// registered.
+func (n *Notifier) HasFD(fd uint32) bool {
+	observed := n.hasFD(fd)
+	return observed
+}
+
+// NonBlockingPoll polls the given fd in non-blocking fashion. It is used just
+// to query the FD's current state; this method will block on the RPC response
+// although the syscall is non-blocking.
+//
+// The poll is retried while it fails with EINTR. On any other error the
+// requested mask itself is returned, i.e. every requested event is reported
+// as ready.
+func (n *Notifier) NonBlockingPoll(fd uint32, mask waiter.EventMask) waiter.EventMask {
+	for {
+		poll := pb.SyscallRequest{Args: &pb.SyscallRequest_Poll{&pb.PollRequest{Fd: fd, Events: uint32(mask)}}}
+		id, done := n.rpcConn.NewRequest(poll, false /* ignoreResult */)
+		<-done
+
+		result := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_Poll).Poll.Result
+		failure, failed := result.(*pb.PollResponse_ErrorNumber)
+		switch {
+		case !failed:
+			return waiter.EventMask(result.(*pb.PollResponse_Events).Events)
+		case syscall.Errno(failure.ErrorNumber) != syscall.EINTR:
+			return mask
+		}
+		// Interrupted; poll again.
+	}
+}
diff --git a/pkg/sentry/socket/rpcinet/rpcinet.go b/pkg/sentry/socket/rpcinet/rpcinet.go
new file mode 100644
index 000000000..10b0dedc2
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/rpcinet.go
@@ -0,0 +1,16 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package rpcinet implements sockets using an RPC for each syscall.
+package rpcinet
diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go
new file mode 100644
index 000000000..574d99ba5
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/socket.go
@@ -0,0 +1,567 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import (
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier"
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/buffer"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// socketOperations implements fs.FileOperations and socket.Socket for a socket
+// whose syscalls are forwarded over an RPC connection.
+type socketOperations struct {
+ socket.ReceiveTimeout
+ fsutil.PipeSeek `state:"nosave"`
+ fsutil.NotDirReaddir `state:"nosave"`
+ fsutil.NoFsync `state:"nosave"`
+ fsutil.NoopFlush `state:"nosave"`
+ fsutil.NoMMap `state:"nosave"`
+
+ // fd is the socket's file descriptor as returned by the Socket or Accept
+ // RPC; it is passed back in every subsequent RPC for this socket.
+ fd uint32 // must be O_NONBLOCK
+ // wq is the queue registered with notifier for fd.
+ wq *waiter.Queue
+ // rpcConn is used by Release to send the Close RPC.
+ rpcConn *conn.RPCConnection
+ // notifier delivers readiness events for fd to wq.
+ notifier *notifier.Notifier
+}
+
+// Compile-time check that *socketOperations satisfies socket.Socket.
+var _ socket.Socket = &socketOperations{}
+
+// newSocketFile creates a new RPC socket and wraps it in a file.
+//
+// The socket is created via a Socket RPC with SOCK_NONBLOCK always added to
+// skType. On success the returned FD is registered with the stack's notifier
+// and a read/write file backed by socketOperations is returned; an error
+// number in the reply is converted with syserr.FromHost.
+func newSocketFile(ctx context.Context, stack *Stack, family int, skType int, protocol int) (*fs.File, *syserr.Error) {
+	create := &pb.SocketRequest{
+		Family:   int64(family),
+		Type:     int64(skType | syscall.SOCK_NONBLOCK),
+		Protocol: int64(protocol),
+	}
+	id, done := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{create}}, false /* ignoreResult */)
+	<-done
+
+	result := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
+	if failure, failed := result.(*pb.SocketResponse_ErrorNumber); failed {
+		return nil, syserr.FromHost(syscall.Errno(failure.ErrorNumber))
+	}
+	sockFD := result.(*pb.SocketResponse_Fd).Fd
+
+	var queue waiter.Queue
+	stack.notifier.AddFD(sockFD, &queue)
+
+	ops := &socketOperations{
+		wq:       &queue,
+		fd:       sockFD,
+		rpcConn:  stack.rpcConn,
+		notifier: stack.notifier,
+	}
+	dirent := socket.NewDirent(ctx, socketDevice)
+	return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, ops), nil
+}
+
+// isBlockingErrno reports whether err is exactly syscall.EAGAIN or
+// syscall.EWOULDBLOCK.
+func isBlockingErrno(err error) bool {
+	if err == syscall.EAGAIN {
+		return true
+	}
+	return err == syscall.EWOULDBLOCK
+}
+
+// translateIOSyscallError maps the "would block" errnos recognised by
+// isBlockingErrno to syserror.ErrWouldBlock and returns every other error
+// (including nil) unchanged.
+func translateIOSyscallError(err error) error {
+	if !isBlockingErrno(err) {
+		return err
+	}
+	return syserror.ErrWouldBlock
+}
+
+// Release implements fs.FileOperations.Release.
+//
+// The FD is unregistered from the notifier and then closed with a Close RPC
+// whose reply is not waited for.
+func (s *socketOperations) Release() {
+	s.notifier.RemoveFD(s.fd)
+
+	// We always need to close the FD.
+	closeReq := pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: s.fd}}}
+	s.rpcConn.NewRequest(closeReq, true /* ignoreResult */)
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+//
+// The current state is obtained by polling the FD through the notifier.
+func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+	ready := s.notifier.NonBlockingPoll(s.fd, mask)
+	return ready
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+//
+// After the entry joins the wait queue the notifier is asked to refresh the
+// FD's event set; the result of that refresh is not checked.
+func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+	queue, fd := s.wq, s.fd
+	queue.EventRegister(e, mask)
+	s.notifier.UpdateFD(fd)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+//
+// After the entry leaves the wait queue the notifier is asked to refresh the
+// FD's event set; the result of that refresh is not checked.
+func (s *socketOperations) EventUnregister(e *waiter.Entry) {
+	queue, fd := s.wq, s.fd
+	queue.EventUnregister(e)
+	s.notifier.UpdateFD(fd)
+}
+
+// rpcRead performs one Read RPC on the task's network stack and waits for
+// the reply. It returns the data payload, or the reply's error number
+// converted with syserr.FromHost.
+func rpcRead(t *kernel.Task, req *pb.SyscallRequest_Read) (*pb.ReadResponse_Data, *syserr.Error) {
+	stack := t.NetworkContext().(*Stack)
+	id, done := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+	<-done
+
+	result := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Read).Read.Result
+	failure, failed := result.(*pb.ReadResponse_ErrorNumber)
+	if !failed {
+		return result.(*pb.ReadResponse_Data), nil
+	}
+	return nil, syserr.FromHost(syscall.Errno(failure.ErrorNumber))
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// Up to dst.NumBytes() bytes are requested through a Read RPC. If the read
+// would block, the task waits for the socket to become readable and retries
+// until it gets data or a different error. ctx must be a *kernel.Task.
+func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+	task := ctx.(*kernel.Task)
+	req := &pb.SyscallRequest_Read{&pb.ReadRequest{
+		Fd:     s.fd,
+		Length: uint32(dst.NumBytes()),
+	}}
+
+	res, se := rpcRead(task, req)
+	if se == syserr.ErrWouldBlock {
+		// We'll have to block. Register for notifications and read again when ready.
+		entry, wake := waiter.NewChannelEntry(nil)
+		s.EventRegister(&entry, waiter.EventIn)
+		defer s.EventUnregister(&entry)
+
+		for {
+			if res, se = rpcRead(task, req); se != syserr.ErrWouldBlock {
+				break
+			}
+			if err := task.Block(wake); err != nil {
+				return 0, err
+			}
+		}
+	}
+
+	if se != nil {
+		return 0, se.ToError()
+	}
+
+	copied, err := dst.CopyOut(ctx, res.Data)
+	return int64(copied), err
+}
+
+// rpcWrite performs one Write RPC on the task's network stack and waits for
+// the reply. It returns the length reported by the reply, or the reply's
+// error number converted with syserr.FromHost.
+func rpcWrite(t *kernel.Task, req *pb.SyscallRequest_Write) (uint32, *syserr.Error) {
+	stack := t.NetworkContext().(*Stack)
+	id, done := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+	<-done
+
+	result := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Write).Write.Result
+	failure, failed := result.(*pb.WriteResponse_ErrorNumber)
+	if !failed {
+		return result.(*pb.WriteResponse_Length).Length, nil
+	}
+	return 0, syserr.FromHost(syscall.Errno(failure.ErrorNumber))
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// All of src is copied into a buffer, which is then sent in a single Write
+// RPC. ctx must be a *kernel.Task.
+func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+	task := ctx.(*kernel.Task)
+
+	// Copy all the data into the buffer.
+	data := buffer.NewView(int(src.NumBytes()))
+	_, copyErr := src.CopyIn(task, data)
+	if copyErr != nil {
+		return 0, copyErr
+	}
+
+	req := &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: data}}
+	written, se := rpcWrite(task, req)
+	return int64(written), se.ToError()
+}
+
+// rpcConnect performs one Connect RPC for fd on the task's network stack and
+// waits for the reply. A non-zero error number in the reply is converted
+// with syserr.FromHost; zero yields nil.
+func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error {
+	stack := t.NetworkContext().(*Stack)
+	req := pb.SyscallRequest{Args: &pb.SyscallRequest_Connect{&pb.ConnectRequest{Fd: uint32(fd), Address: sockaddr}}}
+	id, done := stack.rpcConn.NewRequest(req, false /* ignoreResult */)
+	<-done
+
+	errno := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Connect).Connect.ErrorNumber
+	if errno == 0 {
+		return nil
+	}
+	return syserr.FromHost(syscall.Errno(errno))
+}
+
+// Connect implements socket.Socket.Connect.
+//
+// A non-blocking connect is a single Connect RPC. A blocking connect
+// registers for writability first, issues the RPC, and if the connection is
+// still in progress waits for one notification before repeating the RPC to
+// learn the outcome.
+func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+	if !blocking {
+		return rpcConnect(t, s.fd, sockaddr)
+	}
+
+	// Register for notification when the endpoint becomes writable, then
+	// initiate the connection.
+	entry, wake := waiter.NewChannelEntry(nil)
+	s.EventRegister(&entry, waiter.EventOut)
+	defer s.EventUnregister(&entry)
+
+	switch first := rpcConnect(t, s.fd, sockaddr); first {
+	case syserr.ErrConnectStarted, syserr.ErrAlreadyConnecting:
+		// It's pending, so we have to wait for a notification, and fetch
+		// the result once the wait completes.
+	default:
+		return first
+	}
+
+	if blockErr := t.Block(wake); blockErr != nil {
+		return syserr.FromError(blockErr)
+	}
+
+	// Call Connect() again after blocking to find connect's result.
+	return rpcConnect(t, s.fd, sockaddr)
+}
+
+// rpcAccept performs one Accept RPC for fd on the task's network stack and
+// waits for the reply. SOCK_NONBLOCK is always passed as the accept flags.
+// It returns the reply payload, or the reply's error number converted with
+// syserr.FromHost.
+func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultPayload, *syserr.Error) {
+	stack := t.NetworkContext().(*Stack)
+	req := pb.SyscallRequest{Args: &pb.SyscallRequest_Accept{&pb.AcceptRequest{Fd: fd, Peer: peer, Flags: syscall.SOCK_NONBLOCK}}}
+	id, done := stack.rpcConn.NewRequest(req, false /* ignoreResult */)
+	<-done
+
+	result := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Accept).Accept.Result
+	failure, failed := result.(*pb.AcceptResponse_ErrorNumber)
+	if !failed {
+		return result.(*pb.AcceptResponse_Payload).Payload, nil
+	}
+	return nil, syserr.FromHost(syscall.Errno(failure.ErrorNumber))
+}
+
+// Accept implements socket.Socket.Accept.
+//
+// One Accept RPC is attempted; if it would block and blocking is set, the
+// task waits for readability and retries. On success the new FD is
+// registered with the notifier, wrapped in a socket file (non-blocking if
+// flags contains SOCK_NONBLOCK) and installed in the task's FD table
+// (close-on-exec if flags contains SOCK_CLOEXEC).
+//
+// The peer address and its length are returned only when peerRequested is
+// set; otherwise they are nil and 0.
+func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+	payload, se := rpcAccept(t, s.fd, peerRequested)
+
+	// Check if we need to block.
+	if blocking && se == syserr.ErrWouldBlock {
+		// Register for notifications.
+		e, ch := waiter.NewChannelEntry(nil)
+		s.EventRegister(&e, waiter.EventIn)
+		defer s.EventUnregister(&e)
+
+		// Try to accept the connection again; if it fails, then wait until we
+		// get a notification.
+		for {
+			if payload, se = rpcAccept(t, s.fd, peerRequested); se != syserr.ErrWouldBlock {
+				break
+			}
+
+			if err := t.Block(ch); err != nil {
+				return 0, nil, 0, syserr.FromError(err)
+			}
+		}
+	}
+
+	// Handle any error from accept.
+	if se != nil {
+		return 0, nil, 0, se
+	}
+
+	var wq waiter.Queue
+	s.notifier.AddFD(payload.Fd, &wq)
+
+	dirent := socket.NewDirent(t, socketDevice)
+	file := fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonBlocking: flags&linux.SOCK_NONBLOCK != 0}, &socketOperations{
+		wq: &wq,
+		fd: payload.Fd,
+		// The accepted socket shares the listener's connection; Release
+		// uses it to send the Close RPC and would dereference nil without
+		// it.
+		rpcConn:  s.rpcConn,
+		notifier: s.notifier,
+	})
+
+	fdFlags := kernel.FDFlags{
+		CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+	}
+	fd, err := t.FDMap().NewFDFrom(0, file, fdFlags, t.ThreadGroup().Limits())
+	if err != nil {
+		return 0, nil, 0, syserr.FromError(err)
+	}
+
+	// NOTE(review): the payload's Address is presumably left unset when no
+	// peer was requested, so it is only dereferenced when the caller asked
+	// for it. Confirm against the RPC server.
+	if !peerRequested {
+		return fd, nil, 0, nil
+	}
+	return fd, payload.Address.Address, payload.Address.Length, nil
+}
+
+// Bind implements socket.Socket.Bind.
+//
+// The bind request is forwarded over the RPC connection of the task's
+// network stack. A non-zero errno in the response is converted to a
+// syserr.Error and returned; nil is returned only when the errno is zero.
+func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+	stack := t.NetworkContext().(*Stack)
+	id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: s.fd, Address: sockaddr}}}, false /* ignoreResult */)
+	<-c
+
+	if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 {
+		// Propagate the failure; dropping this value would report every
+		// failed bind as a success.
+		return syserr.FromHost(syscall.Errno(e))
+	}
+	return nil
+}
+
+// Listen implements socket.Socket.Listen.
+//
+// The listen request is forwarded over the RPC connection of the task's
+// network stack. A non-zero errno in the response is converted to a
+// syserr.Error and returned; nil is returned only when the errno is zero.
+func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+	stack := t.NetworkContext().(*Stack)
+	id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Listen{&pb.ListenRequest{Fd: s.fd, Backlog: int64(backlog)}}}, false /* ignoreResult */)
+	<-c
+
+	if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Listen).Listen.ErrorNumber; e != 0 {
+		// Propagate the failure; dropping this value would report every
+		// failed listen as a success.
+		return syserr.FromHost(syscall.Errno(e))
+	}
+	return nil
+}
+
+// Shutdown implements socket.Socket.Shutdown.
+//
+// The request is forwarded over the RPC connection of the task's network
+// stack; a non-zero errno in the response is converted to a syserr.Error.
+func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+	rpc := t.NetworkContext().(*Stack).rpcConn
+	id, done := rpc.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Shutdown{&pb.ShutdownRequest{Fd: s.fd, How: int64(how)}}}, false /* ignoreResult */)
+	<-done
+
+	errno := rpc.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber
+	if errno == 0 {
+		return nil
+	}
+	return syserr.FromHost(syscall.Errno(errno))
+}
+
+// GetSockOpt implements socket.Socket.GetSockOpt.
+//
+// The request is forwarded over the RPC connection of the task's network
+// stack. The response carries either an errno, which is converted to a
+// syserr.Error, or the option bytes, which are returned as-is.
+func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error) {
+	rpc := t.NetworkContext().(*Stack).rpcConn
+	id, done := rpc.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */)
+	<-done
+
+	result := rpc.Request(id).Result.(*pb.SyscallResponse_GetSockOpt).GetSockOpt.Result
+	if errno, failed := result.(*pb.GetSockOptResponse_ErrorNumber); failed {
+		return nil, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+	return result.(*pb.GetSockOptResponse_Opt).Opt, nil
+}
+
+// SetSockOpt implements socket.Socket.SetSockOpt.
+//
+// The request is forwarded over the RPC connection of the task's network
+// stack. A non-zero errno in the response is converted to a syserr.Error and
+// returned; nil is returned only when the errno is zero.
+func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error {
+	stack := t.NetworkContext().(*Stack)
+	id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */)
+	<-c
+
+	if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_SetSockOpt).SetSockOpt.ErrorNumber; e != 0 {
+		// Propagate the failure; dropping this value would report every
+		// failed setsockopt as a success.
+		return syserr.FromHost(syscall.Errno(e))
+	}
+	return nil
+}
+
+// GetPeerName implements socket.Socket.GetPeerName.
+//
+// The request is forwarded over the RPC connection of the task's network
+// stack. The response carries either an errno, which is converted to a
+// syserr.Error, or an address message whose bytes and length are returned.
+func (s *socketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+	rpc := t.NetworkContext().(*Stack).rpcConn
+	id, done := rpc.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
+	<-done
+
+	result := rpc.Request(id).Result.(*pb.SyscallResponse_GetPeerName).GetPeerName.Result
+	if errno, failed := result.(*pb.GetPeerNameResponse_ErrorNumber); failed {
+		return nil, 0, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+
+	peer := result.(*pb.GetPeerNameResponse_Address).Address
+	return peer.Address, peer.Length, nil
+}
+
+// GetSockName implements socket.Socket.GetSockName.
+//
+// The request is forwarded over the RPC connection of the task's network
+// stack. The response carries either an errno, which is converted to a
+// syserr.Error, or an address message whose bytes and length are returned.
+func (s *socketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+	rpc := t.NetworkContext().(*Stack).rpcConn
+	id, done := rpc.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */)
+	<-done
+
+	result := rpc.Request(id).Result.(*pb.SyscallResponse_GetSockName).GetSockName.Result
+	if errno, failed := result.(*pb.GetSockNameResponse_ErrorNumber); failed {
+		return nil, 0, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+
+	local := result.(*pb.GetSockNameResponse_Address).Address
+	return local.Address, local.Length, nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl.
+//
+// No ioctl is forwarded over RPC here: every call fails with ENOTTY
+// regardless of its arguments.
+func (s *socketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return 0, syserror.ENOTTY
+}
+
+// rpcRecvMsg sends req over the RPC connection of the task's network stack
+// and waits for the response. It returns the response payload, or the
+// response's errno converted to a syserr.Error.
+func rpcRecvMsg(t *kernel.Task, req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
+	rpc := t.NetworkContext().(*Stack).rpcConn
+	id, done := rpc.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+	<-done
+
+	result := rpc.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
+	if errno, failed := result.(*pb.RecvmsgResponse_ErrorNumber); failed {
+		return nil, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+	return result.(*pb.RecvmsgResponse_Payload).Payload, nil
+}
+
+// RecvMsg implements socket.Socket.RecvMsg.
+//
+// One receive RPC is attempted up front. If it would block and MSG_DONTWAIT
+// is not set, the task registers for readability and retries after each
+// notification until the RPC succeeds, fails with another error, or the wait
+// itself fails. A wait that ends with ETIMEDOUT is reported as ErrTryAgain.
+// Control messages are never returned.
+func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, interface{}, uint32, unix.ControlMessages, *syserr.Error) {
+	req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
+		Fd:     s.fd,
+		Length: uint32(dst.NumBytes()),
+		Sender: senderRequested,
+		Trunc:  flags&linux.MSG_TRUNC != 0,
+		Peek:   flags&linux.MSG_PEEK != 0,
+	}}
+
+	// First attempt, made before any waiter is registered.
+	payload, recvErr := rpcRecvMsg(t, req)
+	switch {
+	case recvErr == nil:
+		copied, copyErr := dst.CopyOut(t, payload.Data)
+		return int(copied), payload.Address.GetAddress(), payload.Address.GetLength(), unix.ControlMessages{}, syserr.FromError(copyErr)
+	case recvErr != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0:
+		return 0, nil, 0, unix.ControlMessages{}, recvErr
+	}
+
+	// We'll have to block. Register for notifications and keep trying to
+	// receive the data.
+	entry, notifyCh := waiter.NewChannelEntry(nil)
+	s.EventRegister(&entry, waiter.EventIn)
+	defer s.EventUnregister(&entry)
+
+	for {
+		payload, recvErr = rpcRecvMsg(t, req)
+		switch {
+		case recvErr == nil:
+			copied, copyErr := dst.CopyOut(t, payload.Data)
+			return int(copied), payload.Address.GetAddress(), payload.Address.GetLength(), unix.ControlMessages{}, syserr.FromError(copyErr)
+		case recvErr != syserr.ErrWouldBlock:
+			return 0, nil, 0, unix.ControlMessages{}, recvErr
+		}
+
+		waitErr := t.BlockWithDeadline(notifyCh, haveDeadline, deadline)
+		if waitErr == nil {
+			continue
+		}
+		if waitErr == syserror.ETIMEDOUT {
+			return 0, nil, 0, unix.ControlMessages{}, syserr.ErrTryAgain
+		}
+		return 0, nil, 0, unix.ControlMessages{}, syserr.FromError(waitErr)
+	}
+}
+
+// rpcSendMsg sends req over the RPC connection of the task's network stack
+// and waits for the response. It returns the length reported by the
+// response, or the response's errno converted to a syserr.Error.
+func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
+	rpc := t.NetworkContext().(*Stack).rpcConn
+	id, done := rpc.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+	<-done
+
+	result := rpc.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
+	if errno, failed := result.(*pb.SendmsgResponse_ErrorNumber); failed {
+		return 0, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+	return result.(*pb.SendmsgResponse_Length).Length, nil
+}
+
+// SendMsg implements socket.Socket.SendMsg.
+func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
+ // Whitelist flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ // Reject control messages.
+ if !controlMessages.Empty() {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ v := buffer.NewView(int(src.NumBytes()))
+
+ // Copy all the data into the buffer.
+ if _, err := src.CopyIn(t, v); err != nil {
+ return 0, syserr.FromError(err)
+ }
+
+ // TODO: this needs to change to map directly to a SendMsg syscall
+ // in the RPC.
+ req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
+ Fd: uint32(s.fd),
+ Data: v,
+ Address: to,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }}
+
+ n, err := rpcSendMsg(t, req)
+ if err != syserr.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+ return int(n), err
+ }
+
+ // We'll have to block. Register for notification and keep trying to
+ // send all the data.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventOut)
+ defer s.EventUnregister(&e)
+
+ for {
+ n, err := rpcSendMsg(t, req)
+ if err != syserr.ErrWouldBlock {
+ return int(n), err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return 0, syserr.FromError(err)
+ }
+ }
+}
+
+// socketProvider creates RPC-backed sockets for one address family. Its
+// Socket method hands the family to newSocketFile.
+type socketProvider struct {
+ // family is the address family this provider was registered for.
+ family int
+}
+
+// Socket implements socket.Provider.Socket.
+//
+// It returns (nil, nil) when the task is not using the RPC network stack or
+// when the requested type/protocol pair is anything other than a stream
+// socket with protocol 0 or TCP, or a datagram socket with protocol 0 or
+// UDP. Otherwise the socket is created by newSocketFile.
+func (p *socketProvider) Socket(t *kernel.Task, stypeflags unix.SockType, protocol int) (*fs.File, *syserr.Error) {
+	// Check that we are using the RPC network stack. The comma-ok
+	// assertion also fails on a nil network context.
+	s, isRPC := t.NetworkContext().(*Stack)
+	if !isRPC {
+		return nil, nil
+	}
+
+	// Only accept TCP and UDP.
+	//
+	// Try to restrict the flags we will accept to minimize backwards
+	// incompatibility with netstack.
+	stype := int(stypeflags) & linux.SOCK_TYPE_MASK
+	switch {
+	case stype == syscall.SOCK_STREAM && (protocol == 0 || protocol == syscall.IPPROTO_TCP):
+		// ok
+	case stype == syscall.SOCK_DGRAM && (protocol == 0 || protocol == syscall.IPPROTO_UDP):
+		// ok
+	default:
+		return nil, nil
+	}
+
+	return newSocketFile(t, s, p.family, stype, 0)
+}
+
+// Pair implements socket.Provider.Pair.
+//
+// It always returns (nil, nil, nil): no files and no error.
+func (p *socketProvider) Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+ // Not supported by AF_INET/AF_INET6.
+ return nil, nil, nil
+}
+
+// init registers one socketProvider for AF_INET and one for AF_INET6, in
+// that order.
+func init() {
+	socket.RegisterProvider(syscall.AF_INET, &socketProvider{syscall.AF_INET})
+	socket.RegisterProvider(syscall.AF_INET6, &socketProvider{syscall.AF_INET6})
+}
diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go
new file mode 100644
index 000000000..503e0e932
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/stack.go
@@ -0,0 +1,175 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import (
+ "fmt"
+ "strings"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/inet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/hostinet"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/conn"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/notifier"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/unet"
+)
+
+// Stack implements inet.Stack for RPC backed sockets.
+//
+// All configuration fields are populated once by NewStack and are only read
+// afterwards; the setters on Stack reject changes with EACCES.
+type Stack struct {
+ // We intentionally do not allow these values to be changed to remain
+ // consistent with the other networking stacks.
+
+ // interfaces and interfaceAddrs are filled in by
+ // hostinet.ExtractHostInterfaces from the RTM_GETLINK and RTM_GETADDR
+ // netlink dumps.
+ interfaces map[int32]inet.Interface
+ interfaceAddrs map[int32][]inet.InterfaceAddr
+ // supportsIPv6 is true when /proc/net/if_inet6 was read with non-empty
+ // contents.
+ supportsIPv6 bool
+ // tcpRecvBufSize is parsed from /proc/sys/net/ipv4/tcp_rmem.
+ tcpRecvBufSize inet.TCPBufferSize
+ // tcpSendBufSize is parsed from /proc/sys/net/ipv4/tcp_wmem.
+ tcpSendBufSize inet.TCPBufferSize
+ // tcpSACKEnabled is false only when /proc/sys/net/ipv4/tcp_sack trims
+ // to "0".
+ tcpSACKEnabled bool
+ // rpcConn carries every request issued by this stack and its sockets.
+ rpcConn *conn.RPCConnection
+ // notifier is created from rpcConn; accepted sockets register their FDs
+ // with it.
+ notifier *notifier.Notifier
+}
+
+func readTCPBufferSizeFile(conn *conn.RPCConnection, filename string) (inet.TCPBufferSize, error) {
+ contents, se := conn.RPCReadFile(filename)
+ if se != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to read %s: %v", filename, se)
+ }
+ ioseq := usermem.BytesIOSequence(contents)
+ fields := make([]int32, 3)
+ if n, err := usermem.CopyInt32StringsInVec(context.Background(), ioseq.IO, ioseq.Addrs, fields, ioseq.Opts); n != ioseq.NumBytes() || err != nil {
+ return inet.TCPBufferSize{}, fmt.Errorf("failed to parse %s (%q): got %v after %d/%d bytes", filename, contents, err, n, ioseq.NumBytes())
+ }
+ return inet.TCPBufferSize{
+ Min: int(fields[0]),
+ Default: int(fields[1]),
+ Max: int(fields[2]),
+ }, nil
+}
+
+// NewStack returns a Stack containing the current state of the host network
+// stack.
+//
+// fd is wrapped in a unet socket that carries the RPC connection. Over that
+// connection NewStack creates the notifier, reads the TCP buffer sizes and
+// the SACK setting from procfs, checks for IPv6 support, and dumps links and
+// addresses through netlink. Any failure other than reading
+// /proc/net/if_inet6 aborts construction and is returned.
+func NewStack(fd int32) (*Stack, error) {
+	sock, err := unet.NewSocket(int(fd))
+	if err != nil {
+		return nil, err
+	}
+
+	stack := &Stack{
+		interfaces:     make(map[int32]inet.Interface),
+		interfaceAddrs: make(map[int32][]inet.InterfaceAddr),
+		rpcConn:        conn.NewRPCConnection(sock),
+	}
+
+	var e error
+	stack.notifier, e = notifier.NewRPCNotifier(stack.rpcConn)
+	if e != nil {
+		return nil, e
+	}
+
+	// Load the configuration values from procfs.
+	tcpRMem, e := readTCPBufferSizeFile(stack.rpcConn, "/proc/sys/net/ipv4/tcp_rmem")
+	if e != nil {
+		return nil, e
+	}
+	stack.tcpRecvBufSize = tcpRMem
+
+	tcpWMem, e := readTCPBufferSizeFile(stack.rpcConn, "/proc/sys/net/ipv4/tcp_wmem")
+	if e != nil {
+		return nil, e
+	}
+	stack.tcpSendBufSize = tcpWMem
+
+	// The read error is deliberately discarded: a failed or empty read of
+	// if_inet6 simply leaves IPv6 support reported as false.
+	ipv6, _ := stack.rpcConn.RPCReadFile("/proc/net/if_inet6")
+	if len(ipv6) > 0 {
+		stack.supportsIPv6 = true
+	}
+
+	sackFile := "/proc/sys/net/ipv4/tcp_sack"
+	sack, se := stack.rpcConn.RPCReadFile(sackFile)
+	if se != nil {
+		return nil, fmt.Errorf("failed to read %s: %v", sackFile, se)
+	}
+	// Anything other than "0" counts as enabled.
+	stack.tcpSACKEnabled = strings.TrimSpace(string(sack)) != "0"
+
+	links, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETLINK)
+	if err != nil {
+		return nil, fmt.Errorf("RTM_GETLINK failed: %v", err)
+	}
+
+	addrs, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETADDR)
+	if err != nil {
+		return nil, fmt.Errorf("RTM_GETADDR failed: %v", err)
+	}
+
+	e = hostinet.ExtractHostInterfaces(links, addrs, stack.interfaces, stack.interfaceAddrs)
+	if e != nil {
+		return nil, e
+	}
+
+	return stack, nil
+}
+
+// Interfaces implements inet.Stack.Interfaces.
+//
+// It returns the stack's own map, not a copy.
+func (s *Stack) Interfaces() map[int32]inet.Interface {
+ return s.interfaces
+}
+
+// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
+//
+// It returns the stack's own map, not a copy.
+func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
+ return s.interfaceAddrs
+}
+
+// SupportsIPv6 implements inet.Stack.SupportsIPv6.
+//
+// It reports the value recorded when the stack was created.
+func (s *Stack) SupportsIPv6() bool {
+ return s.supportsIPv6
+}
+
+// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
+//
+// It returns the stored value; the error is always nil.
+func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpRecvBufSize, nil
+}
+
+// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
+//
+// The size argument is ignored and the call always fails with EACCES.
+func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+ // To keep all the supported stacks consistent we don't allow changing this
+ // value even though it would be possible via an RPC.
+ return syserror.EACCES
+}
+
+// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
+//
+// It returns the stored value; the error is always nil.
+func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
+ return s.tcpSendBufSize, nil
+}
+
+// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
+//
+// The size argument is ignored and the call always fails with EACCES.
+func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+ // To keep all the supported stacks consistent we don't allow changing this
+ // value even though it would be possible via an RPC.
+ return syserror.EACCES
+}
+
+// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
+//
+// It returns the stored value; the error is always nil.
+func (s *Stack) TCPSACKEnabled() (bool, error) {
+ return s.tcpSACKEnabled, nil
+}
+
+// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
+//
+// The enabled argument is ignored and the call always fails with EACCES.
+func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
+ // To keep all the supported stacks consistent we don't allow changing this
+ // value even though it would be possible via an RPC.
+ return syserror.EACCES
+}
diff --git a/pkg/sentry/socket/rpcinet/stack_unsafe.go b/pkg/sentry/socket/rpcinet/stack_unsafe.go
new file mode 100644
index 000000000..9a896c623
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/stack_unsafe.go
@@ -0,0 +1,193 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package rpcinet
+
+import (
+ "syscall"
+ "unsafe"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/binary"
+ pb "gvisor.googlesource.com/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+)
+
+// newNetlinkRouteRequest builds a netlink message for getting the RIB,
+// the routing information base.
+//
+// The message is a dump request (NLM_F_DUMP | NLM_F_REQUEST) of type proto
+// with sequence number seq and address family family, returned in wire
+// format.
+func newNetlinkRouteRequest(proto, seq, family int) []byte {
+	request := syscall.NetlinkRouteRequest{
+		Header: syscall.NlMsghdr{
+			Len:   uint32(syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg),
+			Type:  uint16(proto),
+			Flags: syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST,
+			Seq:   uint32(seq),
+		},
+		Data: syscall.RtGenmsg{Family: uint8(family)},
+	}
+	return netlinkRRtoWireFormat(&request)
+}
+
+// netlinkRRtoWireFormat serializes rr into a newly allocated buffer of
+// rr.Header.Len bytes. The header fields Len, Type, Flags, Seq and Pid are
+// stored at offsets 0, 4, 6, 8 and 12 through unsafe pointer casts, i.e. in
+// the machine's native byte order, and the family byte goes at offset 16.
+//
+// NOTE(review): rr.Header.Len must be at least 17 or the indexing below
+// panics; newNetlinkRouteRequest sets it to NLMSG_HDRLEN + SizeofRtGenmsg.
+func netlinkRRtoWireFormat(rr *syscall.NetlinkRouteRequest) []byte {
+ b := make([]byte, rr.Header.Len)
+ *(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len
+ *(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type
+ *(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags
+ *(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq
+ *(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid
+ b[16] = byte(rr.Data.Family)
+ return b
+}
+
+// getNetlinkFd asks the RPC peer to create a non-blocking raw
+// AF_NETLINK/NETLINK_ROUTE socket and returns the FD from the response, or
+// the response's errno converted to a syserr.Error.
+func (s *Stack) getNetlinkFd() (uint32, *syserr.Error) {
+	sockReq := &pb.SocketRequest{
+		Family:   int64(syscall.AF_NETLINK),
+		Type:     int64(syscall.SOCK_RAW | syscall.SOCK_NONBLOCK),
+		Protocol: int64(syscall.NETLINK_ROUTE),
+	}
+	id, done := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{sockReq}}, false /* ignoreResult */)
+	<-done
+
+	result := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result
+	if errno, failed := result.(*pb.SocketResponse_ErrorNumber); failed {
+		return 0, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+	return result.(*pb.SocketResponse_Fd).Fd, nil
+}
+
+// bindNetlinkFd sends a bind request for fd with the given raw socket
+// address and converts a non-zero errno in the response to a syserr.Error.
+func (s *Stack) bindNetlinkFd(fd uint32, sockaddr []byte) *syserr.Error {
+	id, done := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: fd, Address: sockaddr}}}, false /* ignoreResult */)
+	<-done
+
+	errno := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber
+	if errno == 0 {
+		return nil
+	}
+	return syserr.FromHost(syscall.Errno(errno))
+}
+
+// closeNetlinkFd sends a close request for fd without waiting for, or
+// looking at, the result (the request is issued with ignoreResult set).
+func (s *Stack) closeNetlinkFd(fd uint32) {
+	closeReq := &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: fd}}
+	s.rpcConn.NewRequest(pb.SyscallRequest{Args: closeReq}, true /* ignoreResult */)
+}
+
+// rpcSendMsg sends req over the stack's RPC connection and waits for the
+// response. It returns the length reported by the response, or the
+// response's errno converted to a syserr.Error.
+func (s *Stack) rpcSendMsg(req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) {
+	id, done := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+	<-done
+
+	result := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result
+	if errno, failed := result.(*pb.SendmsgResponse_ErrorNumber); failed {
+		return 0, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+	return result.(*pb.SendmsgResponse_Length).Length, nil
+}
+
+func (s *Stack) sendMsg(fd uint32, buf []byte, to []byte, flags int) (int, *syserr.Error) {
+ // Whitelist flags.
+ if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 {
+ return 0, syserr.ErrInvalidArgument
+ }
+
+ req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{
+ Fd: fd,
+ Data: buf,
+ Address: to,
+ More: flags&linux.MSG_MORE != 0,
+ EndOfRecord: flags&linux.MSG_EOR != 0,
+ }}
+
+ n, err := s.rpcSendMsg(req)
+ return int(n), err
+}
+
+// rpcRecvMsg sends req over the stack's RPC connection and waits for the
+// response. It returns the response payload, or the response's errno
+// converted to a syserr.Error.
+func (s *Stack) rpcRecvMsg(req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) {
+	id, done := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */)
+	<-done
+
+	result := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result
+	if errno, failed := result.(*pb.RecvmsgResponse_ErrorNumber); failed {
+		return nil, syserr.FromHost(syscall.Errno(errno.ErrorNumber))
+	}
+	return result.(*pb.RecvmsgResponse_Payload).Payload, nil
+}
+
+// recvMsg receives up to l bytes from fd with a single RPC and returns the
+// data from the response. The sender address is never requested; of flags
+// only MSG_TRUNC and MSG_PEEK are forwarded in the request.
+func (s *Stack) recvMsg(fd, l, flags uint32) ([]byte, *syserr.Error) {
+	payload, err := s.rpcRecvMsg(&pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{
+		Fd:     fd,
+		Length: l,
+		Sender: false,
+		Trunc:  flags&linux.MSG_TRUNC != 0,
+		Peek:   flags&linux.MSG_PEEK != 0,
+	}})
+	if err == nil {
+		return payload.Data, nil
+	}
+	return nil, err
+}
+
+// netlinkRequest performs one netlink dump over RPC and returns the
+// concatenated raw bytes of every datagram received.
+//
+// It opens a netlink socket, binds it to an AF_NETLINK address with all
+// other fields zero, sends a dump request of type proto for the given
+// family, and reads page-sized datagrams until one contains an NLMSG_DONE
+// message. The socket is closed on every return path. A datagram shorter
+// than a netlink header, or any NLMSG_ERROR message, is reported as
+// ErrInvalidArgument.
+func (s *Stack) netlinkRequest(proto, family int) ([]byte, error) {
+ fd, err := s.getNetlinkFd()
+ if err != nil {
+ return nil, err.ToError()
+ }
+ defer s.closeNetlinkFd(fd)
+
+ // The same marshalled address is used for the bind and as the send
+ // destination below.
+ lsa := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK}
+ b := binary.Marshal(nil, usermem.ByteOrder, &lsa)
+ if err := s.bindNetlinkFd(fd, b); err != nil {
+ return nil, err.ToError()
+ }
+
+ // The sequence number is fixed at 1.
+ wb := newNetlinkRouteRequest(proto, 1, family)
+ _, err = s.sendMsg(fd, wb, b, 0)
+ if err != nil {
+ return nil, err.ToError()
+ }
+
+ var tab []byte
+done:
+ for {
+ rb, err := s.recvMsg(fd, uint32(syscall.Getpagesize()), 0)
+ nr := len(rb)
+ if err != nil {
+ return nil, err.ToError()
+ }
+
+ if nr < syscall.NLMSG_HDRLEN {
+ return nil, syserr.ErrInvalidArgument.ToError()
+ }
+
+ // Accumulate the raw bytes; the datagram is parsed here only to look
+ // for the terminating or error message.
+ tab = append(tab, rb...)
+ msgs, e := syscall.ParseNetlinkMessage(rb)
+ if e != nil {
+ return nil, e
+ }
+
+ for _, m := range msgs {
+ if m.Header.Type == syscall.NLMSG_DONE {
+ break done
+ }
+ if m.Header.Type == syscall.NLMSG_ERROR {
+ return nil, syserr.ErrInvalidArgument.ToError()
+ }
+ }
+ }
+
+ return tab, nil
+}
+
+// DoNetlinkRouteRequest returns routing information base, also known as RIB,
+// which consists of network facility information, states and parameters.
+//
+// req is the netlink message type to dump. The request is made for
+// AF_UNSPEC and the raw reply is parsed into individual netlink messages.
+func (s *Stack) DoNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) {
+	raw, err := s.netlinkRequest(req, syscall.AF_UNSPEC)
+	if err == nil {
+		return syscall.ParseNetlinkMessage(raw)
+	}
+	return nil, err
+}
diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
new file mode 100644
index 000000000..b845b1bce
--- /dev/null
+++ b/pkg/sentry/socket/rpcinet/syscall_rpc.proto
@@ -0,0 +1,351 @@
+syntax = "proto3";
+
+// package syscall_rpc is a set of networking related system calls that can be
+// forwarded to a socket gofer.
+//
+// TODO: Document individual RPCs.
+package syscall_rpc;
+
+// SendmsgRequest asks the peer to send data on fd. address is the raw
+// destination socket address. more and end_of_record are set by the sentry
+// from the MSG_MORE and MSG_EOR flags.
+message SendmsgRequest {
+ uint32 fd = 1;
+ bytes data = 2;
+ bytes address = 3;
+ bool more = 4;
+ bool end_of_record = 5;
+}
+
+// SendmsgResponse carries either an errno or the length reported for the
+// send.
+message SendmsgResponse {
+ oneof result {
+ uint32 error_number = 1;
+ uint32 length = 2;
+ }
+}
+
+// IOCtlRequest names an fd, a command and one 64-bit argument.
+// NOTE(review): presumably forwarded to ioctl(2) by the peer; confirm
+// against the gofer implementation.
+message IOCtlRequest {
+ uint32 fd = 1;
+ uint32 cmd = 2;
+ uint64 arg = 3;
+}
+
+// IOCtlResponse carries either an errno or a 64-bit value.
+message IOCtlResponse {
+ oneof result {
+ uint32 error_number = 1;
+ uint64 value = 2;
+ }
+}
+
+// RecvmsgRequest asks the peer to receive up to length bytes from fd.
+// sender requests the sender's address in the reply. peek and trunc are set
+// by the sentry from the MSG_PEEK and MSG_TRUNC flags.
+message RecvmsgRequest {
+ uint32 fd = 1;
+ uint32 length = 2;
+ bool sender = 3;
+ bool peek = 4;
+ bool trunc = 5;
+}
+
+// OpenRequest names a path together with flags and mode.
+// NOTE(review): presumably the arguments of open(2); confirm against the
+// gofer implementation.
+message OpenRequest {
+ bytes path = 1;
+ uint32 flags = 2;
+ uint32 mode = 3;
+}
+
+// OpenResponse carries either an errno or an fd.
+message OpenResponse {
+ oneof result {
+ uint32 error_number = 1;
+ uint32 fd = 2;
+ }
+}
+
+// ReadRequest names an fd and a length.
+// NOTE(review): presumably the maximum number of bytes to read; confirm.
+message ReadRequest {
+ uint32 fd = 1;
+ uint32 length = 2;
+}
+
+// ReadResponse carries either an errno or the bytes read.
+message ReadResponse {
+ oneof result {
+ uint32 error_number = 1;
+ bytes data = 2;
+ }
+}
+
+// ReadFileRequest names a file by path.
+// NOTE(review): presumably what RPCReadFile sends to fetch procfs files;
+// confirm in the conn package.
+message ReadFileRequest {
+ string path = 1;
+}
+
+// ReadFileResponse carries either an errno or the file's bytes.
+message ReadFileResponse {
+ oneof result {
+ uint32 error_number = 1;
+ bytes data = 2;
+ }
+}
+
+// WriteRequest names an fd and the bytes to write to it.
+message WriteRequest {
+ uint32 fd = 1;
+ bytes data = 2;
+}
+
+// WriteResponse carries either an errno or a length.
+// NOTE(review): presumably the number of bytes written; confirm.
+message WriteResponse {
+ oneof result {
+ uint32 error_number = 1;
+ uint32 length = 2;
+ }
+}
+
+// WriteFileRequest names a file by path and the content to write to it.
+message WriteFileRequest {
+ string path = 1;
+ bytes content = 2;
+}
+
+// WriteFileResponse carries both an errno and a written count as plain
+// fields rather than a oneof, so both can be present at once.
+message WriteFileResponse {
+ uint32 error_number = 1;
+ uint32 written = 2;
+}
+
+// AddressResponse is a raw socket address with its length. It is embedded
+// in the recvmsg, accept, getsockname and getpeername responses.
+message AddressResponse {
+ bytes address = 1;
+ uint32 length = 2;
+}
+
+// RecvmsgResponse carries either an errno or a payload holding the received
+// data and an address message.
+// NOTE(review): the sentry code visible in this package reads data and
+// address only; the meaning of ResultPayload.length is not established
+// there.
+message RecvmsgResponse {
+ message ResultPayload {
+ bytes data = 1;
+ AddressResponse address = 2;
+ uint32 length = 3;
+ }
+ oneof result {
+ uint32 error_number = 1;
+ ResultPayload payload = 2;
+ }
+}
+
+// BindRequest asks the peer to bind fd to the raw socket address.
+message BindRequest {
+ uint32 fd = 1;
+ bytes address = 2;
+}
+
+// BindResponse carries the errno of the bind; the sentry treats zero as
+// success.
+message BindResponse {
+ uint32 error_number = 1;
+}
+
+// AcceptRequest asks the peer to accept a connection on fd. peer requests
+// the peer address in the reply. The sentry always sets flags to
+// SOCK_NONBLOCK.
+message AcceptRequest {
+ uint32 fd = 1;
+ bool peer = 2;
+ int64 flags = 3;
+}
+
+// AcceptResponse carries either an errno or a payload holding the accepted
+// socket's fd and an address message.
+message AcceptResponse {
+ message ResultPayload {
+ uint32 fd = 1;
+ AddressResponse address = 2;
+ }
+ oneof result {
+ uint32 error_number = 1;
+ ResultPayload payload = 2;
+ }
+}
+
+// ConnectRequest asks the peer to connect fd to the raw socket address.
+message ConnectRequest {
+ uint32 fd = 1;
+ bytes address = 2;
+}
+
+// ConnectResponse carries the errno of the connect; the sentry treats zero
+// as success.
+message ConnectResponse {
+ uint32 error_number = 1;
+}
+
+// ListenRequest asks the peer to listen on fd with the given backlog.
+message ListenRequest {
+ uint32 fd = 1;
+ int64 backlog = 2;
+}
+
+// ListenResponse carries the errno of the listen call.
+message ListenResponse {
+ uint32 error_number = 1;
+}
+
+// ShutdownRequest asks the peer to shut down fd; how is the caller's "how"
+// argument widened to 64 bits.
+message ShutdownRequest {
+ uint32 fd = 1;
+ int64 how = 2;
+}
+
+// ShutdownResponse carries the errno of the shutdown; the sentry treats
+// zero as success.
+message ShutdownResponse {
+ uint32 error_number = 1;
+}
+
+// CloseRequest asks the peer to close fd.
+message CloseRequest {
+ uint32 fd = 1;
+}
+
+// CloseResponse carries the errno of the close. The netlink helper in the
+// sentry sends its close with ignoreResult set and never reads this.
+message CloseResponse {
+ uint32 error_number = 1;
+}
+
+// GetSockOptRequest asks for the socket option (level, name) of fd; length
+// is the caller's output buffer length.
+message GetSockOptRequest {
+ uint32 fd = 1;
+ int64 level = 2;
+ int64 name = 3;
+ uint32 length = 4;
+}
+
+// GetSockOptResponse carries either an errno or the raw option bytes.
+message GetSockOptResponse {
+ oneof result {
+ uint32 error_number = 1;
+ bytes opt = 2;
+ }
+}
+
+// SetSockOptRequest asks the peer to set the socket option (level, name) of
+// fd to the raw bytes in opt.
+message SetSockOptRequest {
+ uint32 fd = 1;
+ int64 level = 2;
+ int64 name = 3;
+ bytes opt = 4;
+}
+
+// SetSockOptResponse carries the errno of the setsockopt call.
+message SetSockOptResponse {
+ uint32 error_number = 1;
+}
+
+// GetSockNameRequest asks for the local address of fd.
+message GetSockNameRequest {
+ uint32 fd = 1;
+}
+
+// GetSockNameResponse carries either an errno or an address message.
+message GetSockNameResponse {
+ oneof result {
+ uint32 error_number = 1;
+ AddressResponse address = 2;
+ }
+}
+
+// GetPeerNameRequest asks for the peer address of fd.
+message GetPeerNameRequest {
+ uint32 fd = 1;
+}
+
+// GetPeerNameResponse carries either an errno or an address message.
+message GetPeerNameResponse {
+ oneof result {
+ uint32 error_number = 1;
+ AddressResponse address = 2;
+ }
+}
+
+// SocketRequest asks the peer to create a socket with the given family,
+// type and protocol.
+message SocketRequest {
+ int64 family = 1;
+ int64 type = 2;
+ int64 protocol = 3;
+}
+
+// SocketResponse carries either an errno or the fd of the new socket.
+message SocketResponse {
+ oneof result {
+ uint32 error_number = 1;
+ uint32 fd = 2;
+ }
+}
+
+// EpollWaitRequest names an epoll fd, an event count and a signed
+// millisecond value.
+// NOTE(review): presumably the maxevents and timeout arguments of
+// epoll_wait(2); confirm against the notifier.
+message EpollWaitRequest {
+ uint32 fd = 1;
+ uint32 num_events = 2;
+ sint64 msec = 3;
+}
+
+// EpollEvent pairs an fd with an event mask.
+message EpollEvent {
+ uint32 fd = 1;
+ uint32 events = 2;
+}
+
+// EpollEvents is a list of EpollEvent entries.
+message EpollEvents {
+ repeated EpollEvent events = 1;
+}
+
+// EpollWaitResponse carries either an errno or a list of events.
+message EpollWaitResponse {
+ oneof result {
+ uint32 error_number = 1;
+ EpollEvents events = 2;
+ }
+}
+
+// EpollCtlRequest names an epoll fd, an operation, a target fd and an
+// event.
+// NOTE(review): presumably the arguments of epoll_ctl(2); confirm against
+// the notifier.
+message EpollCtlRequest {
+ uint32 epfd = 1;
+ int64 op = 2;
+ uint32 fd = 3;
+ EpollEvent event = 4;
+}
+
+// EpollCtlResponse carries the errno of the epoll control operation.
+message EpollCtlResponse {
+ uint32 error_number = 1;
+}
+
+// EpollCreate1Request carries a single flag value.
+// NOTE(review): presumably the flags argument of epoll_create1(2); confirm.
+message EpollCreate1Request {
+ int64 flag = 1;
+}
+
+// EpollCreate1Response carries either an errno or an fd.
+message EpollCreate1Response {
+ oneof result {
+ uint32 error_number = 1;
+ uint32 fd = 2;
+ }
+}
+
+// PollRequest names an fd and an event mask.
+message PollRequest {
+ uint32 fd = 1;
+ uint32 events = 2;
+}
+
+// PollResponse carries either an errno or an event mask.
+message PollResponse {
+ oneof result {
+ uint32 error_number = 1;
+ uint32 events = 2;
+ }
+}
+
+// SyscallRequest is the envelope for every request: exactly one of the
+// per-call request messages is set in args. Each field number here matches
+// the number of the corresponding response in SyscallResponse.result.
+message SyscallRequest {
+ oneof args {
+ SocketRequest socket = 1;
+ SendmsgRequest sendmsg = 2;
+ RecvmsgRequest recvmsg = 3;
+ BindRequest bind = 4;
+ AcceptRequest accept = 5;
+ ConnectRequest connect = 6;
+ ListenRequest listen = 7;
+ ShutdownRequest shutdown = 8;
+ CloseRequest close = 9;
+ GetSockOptRequest get_sock_opt = 10;
+ SetSockOptRequest set_sock_opt = 11;
+ GetSockNameRequest get_sock_name = 12;
+ GetPeerNameRequest get_peer_name = 13;
+ EpollWaitRequest epoll_wait = 14;
+ EpollCtlRequest epoll_ctl = 15;
+ EpollCreate1Request epoll_create1 = 16;
+ PollRequest poll = 17;
+ ReadRequest read = 18;
+ WriteRequest write = 19;
+ OpenRequest open = 20;
+ IOCtlRequest ioctl = 21;
+ WriteFileRequest write_file = 22;
+ ReadFileRequest read_file = 23;
+ }
+}
+
+// SyscallResponse is the envelope for every response: exactly one of the
+// per-call response messages is set in result. The sentry type-asserts the
+// variant that matches the request it sent. Each field number here matches
+// the number of the corresponding request in SyscallRequest.args.
+message SyscallResponse {
+ oneof result {
+ SocketResponse socket = 1;
+ SendmsgResponse sendmsg = 2;
+ RecvmsgResponse recvmsg = 3;
+ BindResponse bind = 4;
+ AcceptResponse accept = 5;
+ ConnectResponse connect = 6;
+ ListenResponse listen = 7;
+ ShutdownResponse shutdown = 8;
+ CloseResponse close = 9;
+ GetSockOptResponse get_sock_opt = 10;
+ SetSockOptResponse set_sock_opt = 11;
+ GetSockNameResponse get_sock_name = 12;
+ GetPeerNameResponse get_peer_name = 13;
+ EpollWaitResponse epoll_wait = 14;
+ EpollCtlResponse epoll_ctl = 15;
+ EpollCreate1Response epoll_create1 = 16;
+ PollResponse poll = 17;
+ ReadResponse read = 18;
+ WriteResponse write = 19;
+ OpenResponse open = 20;
+ IOCtlResponse ioctl = 21;
+ WriteFileResponse write_file = 22;
+ ReadFileResponse read_file = 23;
+ }
+}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
new file mode 100644
index 000000000..be3026bfa
--- /dev/null
+++ b/pkg/sentry/socket/socket.go
@@ -0,0 +1,205 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package socket provides the interfaces that need to be provided by socket
+// implementations and providers, as well as per family demultiplexing of socket
+// creation.
+package socket
+
+import (
+ "fmt"
+ "sync/atomic"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+)
+
+// Socket is the interface containing socket syscalls used by the syscall layer
+// to redirect them to the appropriate implementation.
+//
+// Errors are reported as *syserr.Error; a nil error means success.
+type Socket interface {
+ fs.FileOperations
+
+ // Connect implements the connect(2) linux syscall.
+ //
+ // sockaddr is the raw socket address supplied by the caller.
+ Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
+
+ // Accept implements the accept4(2) linux syscall.
+ // Returns fd, peer address, real peer address length and error. The peer
+ // address and its length are only set if peerRequested is true.
+ Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error)
+
+ // Bind implements the bind(2) linux syscall.
+ Bind(t *kernel.Task, sockaddr []byte) *syserr.Error
+
+ // Listen implements the listen(2) linux syscall.
+ Listen(t *kernel.Task, backlog int) *syserr.Error
+
+ // Shutdown implements the shutdown(2) linux syscall.
+ Shutdown(t *kernel.Task, how int) *syserr.Error
+
+ // GetSockOpt implements the getsockopt(2) linux syscall.
+ GetSockOpt(t *kernel.Task, level int, name int, outLen int) (interface{}, *syserr.Error)
+
+ // SetSockOpt implements the setsockopt(2) linux syscall.
+ SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error
+
+ // GetSockName implements the getsockname(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetSockName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+
+ // GetPeerName implements the getpeername(2) linux syscall.
+ //
+ // addrLen is the address length to be returned to the application, not
+ // necessarily the actual length of the address.
+ GetPeerName(t *kernel.Task) (addr interface{}, addrLen uint32, err *syserr.Error)
+
+ // RecvMsg implements the recvmsg(2) linux syscall.
+ //
+ // senderAddrLen is the address length to be returned to the application,
+ // not necessarily the actual length of the address.
+ //
+ // deadline is only meaningful when haveDeadline is true. controlDataLen
+ // is the amount of space the caller has available for control messages.
+ RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages unix.ControlMessages, err *syserr.Error)
+
+ // SendMsg implements the sendmsg(2) linux syscall. SendMsg does not take
+ // ownership of the ControlMessage on error.
+ SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (n int, err *syserr.Error)
+
+ // SetRecvTimeout sets the timeout (in ns) for recv operations. Zero means
+ // no timeout.
+ SetRecvTimeout(nanoseconds int64)
+
+ // RecvTimeout gets the current timeout (in ns) for recv operations. Zero
+ // means no timeout.
+ RecvTimeout() int64
+}
+
+// Provider is the interface implemented by providers of sockets for specific
+// address families (e.g., AF_INET).
+//
+// Several providers may be registered for one family; New and Pair try them
+// in registration order and use the first one that yields a result.
+type Provider interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error)
+}
+
+// families holds a map of all known address families and their providers.
+//
+// The key is the address family number; the value lists that family's
+// providers in registration order. The map is accessed without any locking,
+// so registration is presumably expected to happen only during package
+// initialization -- TODO confirm.
+var families = make(map[int][]Provider)
+
+// RegisterProvider adds provider to the list of providers consulted for the
+// given address family when sockets are created through the socket() and/or
+// socketpair() syscalls. Providers are kept in registration order.
+func RegisterProvider(family int, provider Provider) {
+	registered := families[family]
+	families[family] = append(registered, provider)
+}
+
+// New creates a new socket with the given family, type and protocol.
+//
+// The family's providers are asked in registration order. The first provider
+// that returns an error or a socket decides the outcome; providers that return
+// neither are skipped. If no provider produces a socket,
+// syserr.ErrAddressFamilyNotSupported is returned.
+func New(t *kernel.Task, family int, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) {
+	for _, prov := range families[family] {
+		file, err := prov.Socket(t, stype, protocol)
+		switch {
+		case err != nil:
+			return nil, err
+		case file != nil:
+			return file, nil
+		}
+	}
+
+	return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// Pair creates a new connected socket pair with the given family, type and
+// protocol.
+//
+// If no provider is registered for family,
+// syserr.ErrAddressFamilyNotSupported is returned. Otherwise the providers are
+// asked in registration order: the first error aborts the search, and the
+// first provider that returns two non-nil files wins. If none does,
+// syserr.ErrSocketNotSupported is returned.
+func Pair(t *kernel.Task, family int, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+	providers, ok := families[family]
+	if !ok {
+		return nil, nil, syserr.ErrAddressFamilyNotSupported
+	}
+
+	for _, p := range providers {
+		// The results get their own names so that they do not shadow the
+		// task parameter t.
+		s1, s2, err := p.Pair(t, stype, protocol)
+		if err != nil {
+			return nil, nil, err
+		}
+		if s1 != nil && s2 != nil {
+			return s1, s2, nil
+		}
+	}
+
+	return nil, nil, syserr.ErrSocketNotSupported
+}
+
+// NewDirent returns a sockfs fs.Dirent that resides on device d.
+//
+// The inode takes the next inode number from d, is owned by the file owner
+// derived from ctx, and is readable and writable by that owner only.
+func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
+	ino := d.NextIno()
+
+	iops := fsutil.NewSimpleInodeOperations(fsutil.InodeSimpleAttributes{
+		FSType: linux.SOCKFS_MAGIC,
+		UAttr: fs.WithCurrentTime(ctx, fs.UnstableAttr{
+			Owner: fs.FileOwnerFromContext(ctx),
+			Perms: fs.FilePermissions{
+				User: fs.PermMask{Read: true, Write: true},
+			},
+			Links: 1,
+		}),
+	})
+
+	// There is no real filesystem backing this socket, so the mount source
+	// is created with a nil Filesystem.
+	msrc := fs.NewNonCachingMountSource(nil, fs.MountSourceFlags{})
+
+	sattr := fs.StableAttr{
+		Type:      fs.Socket,
+		DeviceID:  d.DeviceID(),
+		InodeID:   ino,
+		BlockSize: usermem.PageSize,
+	}
+
+	// Dirent name matches net/socket.c:sockfs_dname.
+	name := fmt.Sprintf("socket:[%d]", ino)
+	return fs.NewDirent(fs.NewInode(iops, msrc, sattr), name)
+}
+
+// ReceiveTimeout stores a timeout for receive calls.
+//
+// It is meant to be embedded into Socket implementations to help satisfy the
+// interface: its SetRecvTimeout and RecvTimeout methods match the two
+// timeout methods required by Socket.
+//
+// Care must be taken when copying ReceiveTimeout as it contains atomic
+// variables.
+type ReceiveTimeout struct {
+ // ns is length of the timeout in nanoseconds. Zero means no timeout.
+ //
+ // ns must be accessed atomically.
+ ns int64
+}
+
+// SetRecvTimeout implements Socket.SetRecvTimeout.
+//
+// The value is stored with an atomic write.
+func (rt *ReceiveTimeout) SetRecvTimeout(nanoseconds int64) {
+ atomic.StoreInt64(&rt.ns, nanoseconds)
+}
+
+// RecvTimeout implements Socket.RecvTimeout.
+//
+// The value is read with an atomic load.
+func (rt *ReceiveTimeout) RecvTimeout() int64 {
+ return atomic.LoadInt64(&rt.ns)
+}
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
new file mode 100644
index 000000000..1ec6eb7ed
--- /dev/null
+++ b/pkg/sentry/socket/unix/BUILD
@@ -0,0 +1,48 @@
+package(licenses = ["notice"]) # Apache 2.0
+
+load("@io_bazel_rules_go//go:def.bzl", "go_library")
+load("//tools/go_stateify:defs.bzl", "go_stateify")
+
+# Generates unix_state.go from unix.go; the generated file is compiled into
+# the go_library below.
+go_stateify(
+ name = "unix_state",
+ srcs = [
+ "unix.go",
+ ],
+ out = "unix_state.go",
+ package = "unix",
+)
+
+# The AF_UNIX socket package. unix_state.go is produced by the unix_state
+# rule above.
+go_library(
+ name = "unix",
+ srcs = [
+ "device.go",
+ "io.go",
+ "unix.go",
+ "unix_state.go",
+ ],
+ importpath = "gvisor.googlesource.com/gvisor/pkg/sentry/socket/unix",
+ visibility = ["//pkg/sentry:internal"],
+ deps = [
+ "//pkg/abi/linux",
+ "//pkg/refs",
+ "//pkg/sentry/arch",
+ "//pkg/sentry/context",
+ "//pkg/sentry/device",
+ "//pkg/sentry/fs",
+ "//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/kernel",
+ "//pkg/sentry/kernel/kdefs",
+ "//pkg/sentry/kernel/time",
+ "//pkg/sentry/safemem",
+ "//pkg/sentry/socket",
+ "//pkg/sentry/socket/control",
+ "//pkg/sentry/socket/epsocket",
+ "//pkg/sentry/usermem",
+ "//pkg/state",
+ "//pkg/syserr",
+ "//pkg/syserror",
+ "//pkg/tcpip",
+ "//pkg/tcpip/transport/unix",
+ "//pkg/waiter",
+ ],
+)
diff --git a/pkg/sentry/socket/unix/device.go b/pkg/sentry/socket/unix/device.go
new file mode 100644
index 000000000..e8bcc7a9f
--- /dev/null
+++ b/pkg/sentry/socket/unix/device.go
@@ -0,0 +1,20 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import "gvisor.googlesource.com/gvisor/pkg/sentry/device"
+
+// unixSocketDevice is the unix socket virtual device. It is an anonymous
+// device created once at package initialization.
+var unixSocketDevice = device.NewAnonDevice()
diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go
new file mode 100644
index 000000000..0ca2e35d0
--- /dev/null
+++ b/pkg/sentry/socket/unix/io.go
@@ -0,0 +1,88 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.googlesource.com/gvisor/pkg/sentry/safemem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+)
+
+// EndpointWriter implements safemem.Writer that writes to a unix.Endpoint.
+//
+// Each WriteFromBlocks call forwards the data, together with Control and To,
+// to Endpoint.SendMsg.
+//
+// EndpointWriter is not thread-safe.
+type EndpointWriter struct {
+ // Endpoint is the unix.Endpoint to write to.
+ Endpoint unix.Endpoint
+
+ // Control is the control messages to send.
+ Control unix.ControlMessages
+
+ // To is the endpoint to send to. May be nil.
+ To unix.BoundEndpoint
+}
+
+// WriteFromBlocks implements safemem.Writer.WriteFromBlocks.
+//
+// The source blocks are handed to Endpoint.SendMsg along with the configured
+// control messages and destination. A netstack error is translated to a Go
+// error, and the byte count from SendMsg is returned even on failure.
+func (w *EndpointWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) {
+	send := func(bufs [][]byte) (int64, error) {
+		sent, terr := w.Endpoint.SendMsg(bufs, w.Control, w.To)
+		if terr == nil {
+			return int64(sent), nil
+		}
+		return int64(sent), syserr.TranslateNetstackError(terr).ToError()
+	}
+	return safemem.FromVecWriterFunc{send}.WriteFromBlocks(srcs)
+}
+
+// EndpointReader implements safemem.Reader that reads from a unix.Endpoint.
+//
+// Creds, NumRights, Peek and From are inputs that the caller sets before
+// reading; MsgSize and Control are outputs overwritten by every ReadToBlocks
+// call.
+//
+// EndpointReader is not thread-safe.
+type EndpointReader struct {
+ // Endpoint is the unix.Endpoint to read from.
+ Endpoint unix.Endpoint
+
+ // Creds indicates if credential control messages are requested.
+ Creds bool
+
+ // NumRights is the number of SCM_RIGHTS FDs requested.
+ NumRights uintptr
+
+ // Peek indicates that the data should not be consumed from the
+ // endpoint.
+ Peek bool
+
+ // MsgSize is the size of the message that was read from. For stream
+ // sockets, it is the amount read.
+ MsgSize uintptr
+
+ // From, if not nil, will be set with the address read from.
+ From *tcpip.FullAddress
+
+ // Control contains the received control messages.
+ Control unix.ControlMessages
+}
+
+// ReadToBlocks implements safemem.Reader.ReadToBlocks.
+//
+// The destination blocks are filled by Endpoint.RecvMsg. The message size and
+// control messages it reports are stored in r.MsgSize and r.Control whether or
+// not the call succeeds. A netstack error is translated to a Go error.
+func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) {
+	recv := func(bufs [][]byte) (int64, error) {
+		got, msgSize, ctrl, terr := r.Endpoint.RecvMsg(bufs, r.Creds, r.NumRights, r.Peek, r.From)
+		r.Control, r.MsgSize = ctrl, msgSize
+		if terr == nil {
+			return int64(got), nil
+		}
+		return int64(got), syserr.TranslateNetstackError(terr).ToError()
+	}
+	return safemem.FromVecReaderFunc{recv}.ReadToBlocks(dsts)
+}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
new file mode 100644
index 000000000..a4b414851
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix.go
@@ -0,0 +1,571 @@
+// Copyright 2018 Google Inc.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package unix provides an implementation of the socket.Socket interface for
+// the AF_UNIX protocol family.
+package unix
+
+import (
+ "strings"
+ "syscall"
+
+ "gvisor.googlesource.com/gvisor/pkg/abi/linux"
+ "gvisor.googlesource.com/gvisor/pkg/refs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/arch"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/context"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/fs/fsutil"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/kdefs"
+ ktime "gvisor.googlesource.com/gvisor/pkg/sentry/kernel/time"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/control"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/socket/epsocket"
+ "gvisor.googlesource.com/gvisor/pkg/sentry/usermem"
+ "gvisor.googlesource.com/gvisor/pkg/syserr"
+ "gvisor.googlesource.com/gvisor/pkg/syserror"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip"
+ "gvisor.googlesource.com/gvisor/pkg/tcpip/transport/unix"
+ "gvisor.googlesource.com/gvisor/pkg/waiter"
+)
+
+// SocketOperations is a Unix socket. It is similar to an epsocket, except it is backed
+// by a unix.Endpoint instead of a tcpip.Endpoint.
+//
+// It is reference counted (see DecRef): the endpoint is closed when the last
+// reference is dropped.
+type SocketOperations struct {
+ refs.AtomicRefCount
+ socket.ReceiveTimeout
+ fsutil.PipeSeek `state:"nosave"`
+ fsutil.NotDirReaddir `state:"nosave"`
+ fsutil.NoFsync `state:"nosave"`
+ fsutil.NoopFlush `state:"nosave"`
+ fsutil.NoMMap `state:"nosave"`
+ // ep is the endpoint that all socket operations are forwarded to.
+ ep unix.Endpoint
+}
+
+// New creates a new unix socket backed by endpoint. The file is opened for
+// reading and writing and gets a fresh dirent on the unix socket device.
+func New(ctx context.Context, endpoint unix.Endpoint) *fs.File {
+	readWrite := fs.FileFlags{Read: true, Write: true}
+	return NewWithDirent(ctx, socket.NewDirent(ctx, unixSocketDevice), endpoint, readWrite)
+}
+
+// NewWithDirent creates a new unix socket using an existing dirent. The file
+// is opened with the given flags and is backed by ep.
+func NewWithDirent(ctx context.Context, d *fs.Dirent, ep unix.Endpoint, flags fs.FileFlags) *fs.File {
+	ops := &SocketOperations{ep: ep}
+	return fs.NewFile(ctx, d, flags, ops)
+}
+
+// DecRef implements RefCounter.DecRef. The endpoint is closed by the
+// destructor passed to DecRefWithDestructor.
+func (s *SocketOperations) DecRef() {
+	closeEndpoint := func() {
+		s.ep.Close()
+	}
+	s.DecRefWithDestructor(closeEndpoint)
+}
+
+// Release implements fs.FileOperations.Release.
+func (s *SocketOperations) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+// Endpoint extracts the unix.Endpoint that backs this socket.
+func (s *SocketOperations) Endpoint() unix.Endpoint {
+ return s.ep
+}
+
+// extractPath extracts and validates the path in an AF_UNIX socket address.
+//
+// An empty path is rejected with syserr.ErrInvalidArgument, and a path ending
+// in '/' with syserr.ErrIsDir. On success the returned path has at least one
+// byte.
+func extractPath(sockaddr []byte) (string, *syserr.Error) {
+	addr, err := epsocket.GetAddress(linux.AF_UNIX, sockaddr)
+	if err != nil {
+		return "", err
+	}
+
+	// The address is trimmed by GetAddress.
+	path := string(addr.Addr)
+	switch {
+	case len(path) == 0:
+		// Not allowed.
+		return "", syserr.ErrInvalidArgument
+	case strings.HasSuffix(path, "/"):
+		// Something like '/a/b/c/' names a directory, not a socket.
+		return "", syserr.ErrIsDir
+	}
+
+	return path, nil
+}
+
+// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
+// a unix.Endpoint. It returns the endpoint's remote address converted to its
+// AF_UNIX form, together with the length to report to the application.
+func (s *SocketOperations) GetPeerName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+	remote, terr := s.ep.GetRemoteAddress()
+	if terr != nil {
+		return nil, 0, syserr.TranslateNetstackError(terr)
+	}
+
+	sockaddr, size := epsocket.ConvertAddress(linux.AF_UNIX, remote)
+	return sockaddr, size, nil
+}
+
+// GetSockName implements the linux syscall getsockname(2) for sockets backed by
+// a unix.Endpoint. It returns the endpoint's local address converted to its
+// AF_UNIX form, together with the length to report to the application.
+func (s *SocketOperations) GetSockName(t *kernel.Task) (interface{}, uint32, *syserr.Error) {
+	local, terr := s.ep.GetLocalAddress()
+	if terr != nil {
+		return nil, 0, syserr.TranslateNetstackError(terr)
+	}
+
+	sockaddr, size := epsocket.ConvertAddress(linux.AF_UNIX, local)
+	return sockaddr, size, nil
+}
+
+// Ioctl implements fs.FileOperations.Ioctl by delegating to epsocket.Ioctl
+// with this socket's endpoint.
+func (s *SocketOperations) Ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return epsocket.Ioctl(ctx, s.ep, io, args)
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a unix.Endpoint. The work is delegated to epsocket.GetSockOpt, which is
+// given the AF_UNIX family and the endpoint's socket type.
+func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name, outLen int) (interface{}, *syserr.Error) {
+ return epsocket.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// Listen implements the linux syscall listen(2) for sockets backed by
+// a unix.Endpoint. The endpoint's netstack error is translated to a
+// *syserr.Error.
+func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+ return syserr.TranslateNetstackError(s.ep.Listen(backlog))
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accepted, it will block until one becomes ready.
+//
+// It returns as soon as the endpoint's Accept yields anything other than
+// tcpip.ErrWouldBlock, or when blocking the task fails.
+func (s *SocketOperations) blockingAccept(t *kernel.Task) (unix.Endpoint, *syserr.Error) {
+	// Register for readability notifications for the duration of the wait.
+	entry, notifyCh := waiter.NewChannelEntry(nil)
+	s.EventRegister(&entry, waiter.EventIn)
+	defer s.EventUnregister(&entry)
+
+	for {
+		// Retry the accept; only a would-block result keeps us waiting.
+		accepted, terr := s.ep.Accept()
+		if terr != tcpip.ErrWouldBlock {
+			return accepted, syserr.TranslateNetstackError(terr)
+		}
+
+		if blockErr := t.Block(notifyCh); blockErr != nil {
+			return nil, syserr.FromError(blockErr)
+		}
+	}
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a unix.Endpoint.
+//
+// If no connection is pending and blocking is true, the task waits for one.
+// The accepted endpoint is wrapped in a new socket file and installed in the
+// task's FD map. SOCK_NONBLOCK and SOCK_CLOEXEC in flags are applied to the
+// new file and descriptor respectively. The peer address and its length are
+// only returned when peerRequested is true.
+func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (kdefs.FD, interface{}, uint32, *syserr.Error) {
+	// Issue the accept request to get the new endpoint.
+	ep, terr := s.ep.Accept()
+	if terr != nil {
+		if terr != tcpip.ErrWouldBlock || !blocking {
+			return 0, nil, 0, syserr.TranslateNetstackError(terr)
+		}
+
+		var serr *syserr.Error
+		if ep, serr = s.blockingAccept(t); serr != nil {
+			return 0, nil, 0, serr
+		}
+	}
+
+	// The FD map takes its own reference; ours is dropped on return.
+	ns := New(t, ep)
+	defer ns.DecRef()
+
+	if flags&linux.SOCK_NONBLOCK != 0 {
+		fileFlags := ns.Flags()
+		fileFlags.NonBlocking = true
+		ns.SetFlags(fileFlags.Settable())
+	}
+
+	var (
+		addr    interface{}
+		addrLen uint32
+	)
+	if peerRequested {
+		// Get address of the peer.
+		var serr *syserr.Error
+		addr, addrLen, serr = ns.FileOperations.(*SocketOperations).GetPeerName(t)
+		if serr != nil {
+			return 0, nil, 0, serr
+		}
+	}
+
+	fdFlags := kernel.FDFlags{
+		CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+	}
+	fd, fdErr := t.FDMap().NewFDFrom(0, ns, fdFlags, t.ThreadGroup().Limits())
+	if fdErr != nil {
+		return 0, nil, 0, syserr.FromError(fdErr)
+	}
+
+	return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+//
+// sockaddr is validated by extractPath. A path whose first byte is 0 is an
+// abstract address and is registered in the task's abstract socket namespace
+// (rejected when the task is network namespaced). Any other path is created
+// as a socket file under its parent directory, which is resolved relative to
+// the task's root and working directory.
+//
+// Errors: syserr.ErrInvalidArgument if the endpoint cannot be bound; otherwise
+// the translated netstack error from the endpoint's Bind, where a failure to
+// register or create the name is reported as tcpip.ErrPortInUse (EADDRINUSE)
+// and an unresolvable parent directory as tcpip.ErrNoSuchFile.
+func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+	p, e := extractPath(sockaddr)
+	if e != nil {
+		return e
+	}
+
+	bep, ok := s.ep.(unix.BoundEndpoint)
+	if !ok {
+		// This socket can't be bound.
+		return syserr.ErrInvalidArgument
+	}
+
+	return syserr.TranslateNetstackError(s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *tcpip.Error {
+		// Is it abstract?
+		if p[0] == 0 {
+			if t.IsNetworkNamespaced() {
+				return tcpip.ErrInvalidEndpointState
+			}
+			if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+				// tcpip.ErrPortInUse corresponds to EADDRINUSE.
+				return tcpip.ErrPortInUse
+			}
+			return nil
+		}
+
+		cwd := t.FSContext().WorkingDirectory()
+		defer cwd.DecRef()
+
+		// RootDirectory returns a reference that must be dropped. Take it
+		// once here and reuse it for both the lookup and the bind, rather
+		// than taking a second, never-released reference for the bind.
+		root := t.FSContext().RootDirectory()
+		defer root.DecRef()
+
+		// The parent and name. With no slash at all, the socket is created
+		// directly in the working directory.
+		d := cwd
+		name := p
+		if lastSlash := strings.LastIndex(p, "/"); lastSlash >= 0 {
+			// Find the last path component, we know that something follows
+			// that final slash, otherwise extractPath() would have failed.
+			subPath := p[:lastSlash]
+			if subPath == "" {
+				// Fix up subpath in case file is in root.
+				subPath = "/"
+			}
+			parent, err := t.MountNamespace().FindInode(t, root, cwd, subPath, fs.DefaultTraversalLimit)
+			if err != nil {
+				// No path available.
+				return tcpip.ErrNoSuchFile
+			}
+			defer parent.DecRef()
+			d = parent
+			name = p[lastSlash+1:]
+		}
+
+		// Create the socket.
+		if err := d.Bind(t, root, name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}}); err != nil {
+			return tcpip.ErrPortInUse
+		}
+
+		return nil
+	}))
+}
+
+// extractEndpoint retrieves the unix.BoundEndpoint associated with a Unix
+// socket path. Release must be called on the returned unix.BoundEndpoint when
+// the caller is done with it.
+//
+// A path whose first byte is 0 is looked up in the task's abstract socket
+// namespace; any other path is resolved in the filesystem. When no socket is
+// bound at the address, syserr.ErrConnectionRefused is returned.
+func extractEndpoint(t *kernel.Task, sockaddr []byte) (unix.BoundEndpoint, *syserr.Error) {
+	path, err := extractPath(sockaddr)
+	if err != nil {
+		return nil, err
+	}
+
+	// Is it abstract?
+	if path[0] == 0 {
+		if t.IsNetworkNamespaced() {
+			return nil, syserr.ErrInvalidArgument
+		}
+
+		bound := t.AbstractSockets().BoundEndpoint(path[1:])
+		if bound != nil {
+			return bound, nil
+		}
+		// No socket found.
+		return nil, syserr.ErrConnectionRefused
+	}
+
+	// Find the node in the filesystem. The directory references are only
+	// needed for the lookup itself.
+	root := t.FSContext().RootDirectory()
+	cwd := t.FSContext().WorkingDirectory()
+	dirent, findErr := t.MountNamespace().FindInode(t, root, cwd, path, fs.DefaultTraversalLimit)
+	cwd.DecRef()
+	root.DecRef()
+	if findErr != nil {
+		return nil, syserr.FromError(findErr)
+	}
+
+	// Extract the endpoint if one is there.
+	bound := dirent.Inode.BoundEndpoint(path)
+	dirent.DecRef()
+	if bound != nil {
+		return bound, nil
+	}
+
+	// No socket!
+	return nil, syserr.ErrConnectionRefused
+}
+
+// Connect implements the linux syscall connect(2) for unix sockets. The
+// address is resolved to a bound endpoint, which is released again once the
+// connection attempt has finished.
+func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+	server, err := extractEndpoint(t, sockaddr)
+	if err != nil {
+		return err
+	}
+	defer server.Release()
+
+	// Connect the server endpoint.
+	connErr := s.ep.Connect(server)
+	return syserr.TranslateNetstackError(connErr)
+}
+
+// Write implements fs.FileOperations.Write.
+//
+// Control messages for the write are built with control.New from the task in
+// ctx. A zero-length source is still passed to the endpoint, as a SendMsg
+// call with no buffers.
+func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
+	task := kernel.TaskFromContext(ctx)
+	ctrl := control.New(task, s.ep, nil)
+
+	if src.NumBytes() != 0 {
+		writer := &EndpointWriter{
+			Endpoint: s.ep,
+			Control:  ctrl,
+			To:       nil,
+		}
+		return src.CopyInTo(ctx, writer)
+	}
+
+	sent, terr := s.ep.SendMsg([][]byte{}, ctrl, nil)
+	return int64(sent), syserr.TranslateNetstackError(terr).ToError()
+}
+
+// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
+// a unix.Endpoint.
+//
+// When to is non-empty it is resolved to a destination endpoint, which is
+// released on return. If the first attempt would block and MSG_DONTWAIT is
+// not set, the task waits for the socket to become writable and retries.
+func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, controlMessages unix.ControlMessages) (int, *syserr.Error) {
+	writer := EndpointWriter{
+		Endpoint: s.ep,
+		Control:  controlMessages,
+	}
+	if len(to) != 0 {
+		dest, err := extractEndpoint(t, to)
+		if err != nil {
+			return 0, err
+		}
+		defer dest.Release()
+		writer.To = dest
+	}
+
+	dontWait := flags&linux.MSG_DONTWAIT != 0
+	sent, sendErr := src.CopyInTo(t, &writer)
+	if sendErr != syserror.ErrWouldBlock || dontWait {
+		return int(sent), syserr.FromError(sendErr)
+	}
+
+	// We'll have to block. Register for notification and keep trying to
+	// send all the data.
+	entry, notifyCh := waiter.NewChannelEntry(nil)
+	s.EventRegister(&entry, waiter.EventOut)
+	defer s.EventUnregister(&entry)
+
+	for {
+		sent, sendErr = src.CopyInTo(t, &writer)
+		if sendErr != syserror.ErrWouldBlock {
+			return int(sent), syserr.FromError(sendErr)
+		}
+
+		if blockErr := t.Block(notifyCh); blockErr != nil {
+			return 0, syserr.FromError(blockErr)
+		}
+	}
+}
+
+// Passcred implements unix.Credentialer.Passcred by forwarding to the
+// endpoint.
+func (s *SocketOperations) Passcred() bool {
+ return s.ep.Passcred()
+}
+
+// ConnectedPasscred implements unix.Credentialer.ConnectedPasscred by
+// forwarding to the endpoint.
+func (s *SocketOperations) ConnectedPasscred() bool {
+ return s.ep.ConnectedPasscred()
+}
+
+// Readiness implements waiter.Waitable.Readiness by forwarding to the
+// endpoint.
+func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.ep.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister by registering e
+// with the endpoint.
+func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.ep.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister by unregistering
+// e from the endpoint.
+func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+ s.ep.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a unix.Endpoint. The work is delegated to epsocket.SetSockOpt.
+func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return epsocket.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// Shutdown implements the linux syscall shutdown(2) for sockets backed by
+// a unix.Endpoint. how is converted with epsocket.ConvertShutdown; a
+// conversion error is returned as is.
+func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+	mode, convErr := epsocket.ConvertShutdown(how)
+	if convErr != nil {
+		return convErr
+	}
+
+	// Issue shutdown request.
+	shutErr := s.ep.Shutdown(mode)
+	return syserr.TranslateNetstackError(shutErr)
+}
+
+// Read implements fs.FileOperations.Read.
+//
+// A zero-length destination returns immediately without touching the
+// endpoint. Otherwise data is consumed from the endpoint with no credentials,
+// rights or sender address requested.
+func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) {
+	if dst.NumBytes() == 0 {
+		return 0, nil
+	}
+
+	// The zero values of the remaining fields mean: no creds, no rights,
+	// no peek and no sender address.
+	reader := EndpointReader{Endpoint: s.ep}
+	return dst.CopyOutFrom(ctx, &reader)
+}
+
+// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
+// a unix.Endpoint.
+//
+// controlDataLen is the space available for control messages; it decides
+// whether credentials are requested and how many SCM_RIGHTS FDs fit. With
+// MSG_TRUNC the returned count is the message size rather than the number of
+// bytes copied. With MSG_PEEK the data is not consumed. The sender address is
+// only returned when senderRequested is true.
+//
+// If the first attempt would block and MSG_DONTWAIT is not set, the task waits
+// for the socket to become readable, honouring deadline when haveDeadline is
+// true. An expired deadline is reported as syserr.ErrTryAgain.
+func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, senderAddr interface{}, senderAddrLen uint32, controlMessages unix.ControlMessages, err *syserr.Error) {
+	trunc := flags&linux.MSG_TRUNC != 0
+	peek := flags&linux.MSG_PEEK != 0
+
+	// Calculate the number of FDs for which we have space and if we are
+	// requesting credentials.
+	var wantCreds bool
+	rightsLen := int(controlDataLen) - syscall.SizeofCmsghdr
+	if s.Passcred() {
+		// Credentials take priority if they are enabled and there is space.
+		wantCreds = rightsLen > 0
+		credLen := syscall.CmsgSpace(syscall.SizeofUcred)
+		rightsLen -= credLen
+	}
+	// FDs are 32 bit (4 byte) ints.
+	numRights := rightsLen / 4
+	if numRights < 0 {
+		numRights = 0
+	}
+
+	r := EndpointReader{
+		Endpoint:  s.ep,
+		Creds:     wantCreds,
+		NumRights: uintptr(numRights),
+		Peek:      peek,
+	}
+	if senderRequested {
+		r.From = &tcpip.FullAddress{}
+	}
+
+	// finish builds the return values from a completed read attempt. It is
+	// shared by the non-blocking and blocking paths below.
+	finish := func(copied int64, copyErr error) (int, interface{}, uint32, unix.ControlMessages, *syserr.Error) {
+		var from interface{}
+		var fromLen uint32
+		if r.From != nil {
+			from, fromLen = epsocket.ConvertAddress(linux.AF_UNIX, *r.From)
+		}
+		if trunc {
+			// MSG_TRUNC reports the size of the message, not the copy.
+			copied = int64(r.MsgSize)
+		}
+		return int(copied), from, fromLen, r.Control, syserr.FromError(copyErr)
+	}
+
+	if copied, copyErr := dst.CopyOutFrom(t, &r); copyErr != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 {
+		return finish(copied, copyErr)
+	}
+
+	// We'll have to block. Register for notification and keep trying to
+	// receive the data.
+	e, ch := waiter.NewChannelEntry(nil)
+	s.EventRegister(&e, waiter.EventIn)
+	defer s.EventUnregister(&e)
+
+	for {
+		if copied, copyErr := dst.CopyOutFrom(t, &r); copyErr != syserror.ErrWouldBlock {
+			return finish(copied, copyErr)
+		}
+
+		if blockErr := t.BlockWithDeadline(ch, haveDeadline, deadline); blockErr != nil {
+			if blockErr == syserror.ETIMEDOUT {
+				return 0, nil, 0, unix.ControlMessages{}, syserr.ErrTryAgain
+			}
+			return 0, nil, 0, unix.ControlMessages{}, syserr.FromError(blockErr)
+		}
+	}
+}
+
+// provider is a unix domain socket provider. It is stateless and is
+// registered for AF_UNIX in this package's init function.
+type provider struct{}
+
+// Socket returns a new unix domain socket.
+//
+// Only protocol 0 is accepted. SOCK_DGRAM yields a connectionless endpoint;
+// SOCK_STREAM and SOCK_SEQPACKET yield a connectioned one. Anything else is
+// rejected with syserr.ErrInvalidArgument.
+func (*provider) Socket(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *syserr.Error) {
+	// Check arguments.
+	if protocol != 0 {
+		return nil, syserr.ErrInvalidArgument
+	}
+
+	// Create the endpoint and socket.
+	switch stype {
+	case linux.SOCK_DGRAM:
+		return New(t, unix.NewConnectionless()), nil
+	case linux.SOCK_STREAM, linux.SOCK_SEQPACKET:
+		return New(t, unix.NewConnectioned(stype, t.Kernel())), nil
+	}
+
+	return nil, syserr.ErrInvalidArgument
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+//
+// Only protocol 0 and the types SOCK_STREAM, SOCK_DGRAM and SOCK_SEQPACKET are
+// accepted; anything else is rejected with syserr.ErrInvalidArgument.
+func (*provider) Pair(t *kernel.Task, stype unix.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) {
+	// Check arguments.
+	if protocol != 0 {
+		return nil, nil, syserr.ErrInvalidArgument
+	}
+	if stype != linux.SOCK_STREAM && stype != linux.SOCK_DGRAM && stype != linux.SOCK_SEQPACKET {
+		return nil, nil, syserr.ErrInvalidArgument
+	}
+
+	// Create the endpoints and wrap each one in a socket file.
+	first, second := unix.NewPair(stype, t.Kernel())
+	return New(t, first), New(t, second), nil
+}
+
+// init registers the unix domain socket provider for the AF_UNIX family.
+func init() {
+ socket.RegisterProvider(linux.AF_UNIX, &provider{})
+}