diff options
Diffstat (limited to 'pkg/sentry/socket')
56 files changed, 5568 insertions, 3207 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD index 26176b10d..c0fd3425b 100644 --- a/pkg/sentry/socket/BUILD +++ b/pkg/sentry/socket/BUILD @@ -1,24 +1,25 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "socket", srcs = ["socket.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/binary", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/tcpip", + "//pkg/usermem", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD index 4a6e83a8b..ca16d0381 100644 --- a/pkg/sentry/socket/control/BUILD +++ b/pkg/sentry/socket/control/BUILD @@ -1,11 +1,13 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "control", - srcs = ["control.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/control", + srcs = [ + "control.go", + "control_vfs2.go", + ], imports = [ "gvisor.dev/gvisor/pkg/sentry/fs", ], @@ -13,12 +15,15 @@ go_library( deps = [ "//pkg/abi/linux", "//pkg/binary", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/fs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", + "//pkg/sentry/socket", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserror", + "//pkg/tcpip", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go index 4e95101b7..70ccf77a7 100644 --- a/pkg/sentry/socket/control/control.go +++ b/pkg/sentry/socket/control/control.go @@ -19,13 +19,15 @@ package control import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" ) const maxInt = int(^uint(0) >> 1) @@ -39,6 +41,8 @@ type SCMCredentials interface { Credentials(t *kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) } +// LINT.IfChange + // SCMRights represents a SCM_RIGHTS socket control message. type SCMRights interface { transport.RightsControlMessage @@ -64,7 +68,7 @@ func NewSCMRights(t *kernel.Task, fds []int32) (SCMRights, error) { for _, fd := range fds { file := t.GetFile(fd) if file == nil { - files.Release() + files.Release(t) return nil, syserror.EBADF } files = append(files, file) @@ -96,9 +100,9 @@ func (fs *RightsFiles) Clone() transport.RightsControlMessage { } // Release implements transport.RightsControlMessage.Release. -func (fs *RightsFiles) Release() { +func (fs *RightsFiles) Release(ctx context.Context) { for _, f := range *fs { - f.DecRef() + f.DecRef(ctx) } *fs = nil } @@ -111,7 +115,7 @@ func rightsFDs(t *kernel.Task, rights SCMRights, cloexec bool, max int) ([]int32 fd, err := t.NewFDFrom(0, files[0], kernel.FDFlags{ CloseOnExec: cloexec, }) - files[0].DecRef() + files[0].DecRef(t) files = files[1:] if err != nil { t.Warningf("Error inserting FD: %v", err) @@ -140,6 +144,8 @@ func PackRights(t *kernel.Task, rights SCMRights, cloexec bool, buf []byte, flag return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds) } +// LINT.ThenChange(./control_vfs2.go) + // scmCredentials represents an SCM_CREDENTIALS socket control message. // // +stateify savable @@ -188,21 +194,21 @@ func putUint32(buf []byte, n uint32) []byte { // putCmsg writes a control message header and as much data as will fit into // the unused capacity of a buffer. func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) { - space := AlignDown(cap(buf)-len(buf), 4) + space := binary.AlignDown(cap(buf)-len(buf), 4) // We can't write to space that doesn't exist, so if we are going to align // the available space, we must align down. // // align must be >= 4 and each data int32 is 4 bytes. The length of the - // header is already aligned, so if we align to the with of the data there + // header is already aligned, so if we align to the width of the data there // are two cases: // 1. The aligned length is less than the length of the header. The // unaligned length was also less than the length of the header, so we // can't write anything. // 2. The aligned length is greater than or equal to the length of the - // header. We can write the header plus zero or more datas. We can't write - // a partial int32, so the length of the message will be - // min(aligned length, header + datas). + // header. We can write the header plus zero or more bytes of data. We can't + // write a partial int32, so the length of the message will be + // min(aligned length, header + data). if space < linux.SizeOfControlMessageHeader { flags |= linux.MSG_CTRUNC return buf, flags @@ -239,12 +245,12 @@ func putCmsgStruct(buf []byte, msgLevel, msgType uint32, align uint, data interf buf = binary.Marshal(buf, usermem.ByteOrder, data) - // Check if we went over. + // If the control message data brought us over capacity, omit it. if cap(buf) != cap(ob) { return hdrBuf } - // Fix up length. + // Update control message length to include data. putUint64(ob, uint64(len(buf)-len(ob))) return alignSlice(buf, align) @@ -281,19 +287,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c) } -// AlignUp rounds a length up to an alignment. align must be a power of 2. -func AlignUp(length int, align uint) int { - return (length + int(align) - 1) & ^(int(align) - 1) -} - -// AlignDown rounds a down to an alignment. align must be a power of 2. -func AlignDown(length int, align uint) int { - return length & ^(int(align) - 1) -} - // alignSlice extends a slice's length (up to the capacity) to align it. func alignSlice(buf []byte, align uint) []byte { - aligned := AlignUp(len(buf), align) + aligned := binary.AlignUp(len(buf), align) if aligned > cap(buf) { // Linux allows unaligned data if there isn't room for alignment. // Since there isn't room for alignment, there isn't room for any @@ -320,35 +316,139 @@ func PackInq(t *kernel.Task, inq int32, buf []byte) []byte { buf, linux.SOL_TCP, linux.TCP_INQ, - 4, + t.Arch().Width(), inq, ) } +// PackTOS packs an IP_TOS socket control message. +func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_TOS, + t.Arch().Width(), + tos, + ) +} + +// PackTClass packs an IPV6_TCLASS socket control message. +func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte { + return putCmsgStruct( + buf, + linux.SOL_IPV6, + linux.IPV6_TCLASS, + t.Arch().Width(), + tClass, + ) +} + +// PackIPPacketInfo packs an IP_PKTINFO socket control message. +func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte { + var p linux.ControlMessageIPPacketInfo + p.NIC = int32(packetInfo.NIC) + copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr)) + copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr)) + + return putCmsgStruct( + buf, + linux.SOL_IP, + linux.IP_PKTINFO, + t.Arch().Width(), + p, + ) +} + +// PackControlMessages packs control messages into the given buffer. +// +// We skip control messages specific to Unix domain sockets. +// +// Note that some control messages may be truncated if they do not fit under +// the capacity of buf. +func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byte) []byte { + if cmsgs.IP.HasTimestamp { + buf = PackTimestamp(t, cmsgs.IP.Timestamp, buf) + } + + if cmsgs.IP.HasInq { + // In Linux, TCP_CM_INQ is added after SO_TIMESTAMP. + buf = PackInq(t, cmsgs.IP.Inq, buf) + } + + if cmsgs.IP.HasTOS { + buf = PackTOS(t, cmsgs.IP.TOS, buf) + } + + if cmsgs.IP.HasTClass { + buf = PackTClass(t, cmsgs.IP.TClass, buf) + } + + if cmsgs.IP.HasIPPacketInfo { + buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf) + } + + return buf +} + +// cmsgSpace is equivalent to CMSG_SPACE in Linux. +func cmsgSpace(t *kernel.Task, dataLen int) int { + return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width()) +} + +// CmsgsSpace returns the number of bytes needed to fit the control messages +// represented in cmsgs. +func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int { + space := 0 + + if cmsgs.IP.HasTimestamp { + space += cmsgSpace(t, linux.SizeOfTimeval) + } + + if cmsgs.IP.HasInq { + space += cmsgSpace(t, linux.SizeOfControlMessageInq) + } + + if cmsgs.IP.HasTOS { + space += cmsgSpace(t, linux.SizeOfControlMessageTOS) + } + + if cmsgs.IP.HasTClass { + space += cmsgSpace(t, linux.SizeOfControlMessageTClass) + } + + return space +} + +// NewIPPacketInfo returns the IPPacketInfo struct. +func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo { + var p tcpip.IPPacketInfo + p.NIC = tcpip.NICID(packetInfo.NIC) + copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:]) + copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:]) + + return p +} + // Parse parses a raw socket control message into portable objects. -func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport.ControlMessages, error) { +func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) { var ( - fds linux.ControlMessageRights - haveCreds bool - creds linux.ControlMessageCredentials + cmsgs socket.ControlMessages + fds linux.ControlMessageRights ) for i := 0; i < len(buf); { if i+linux.SizeOfControlMessageHeader > len(buf) { - return transport.ControlMessages{}, syserror.EINVAL + return cmsgs, syserror.EINVAL } var h linux.ControlMessageHeader binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageHeader], usermem.ByteOrder, &h) if h.Length < uint64(linux.SizeOfControlMessageHeader) { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } if h.Length > uint64(len(buf)-i) { - return transport.ControlMessages{}, syserror.EINVAL - } - if h.Level != linux.SOL_SOCKET { - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } i += linux.SizeOfControlMessageHeader @@ -358,59 +458,105 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (transport. // sizeof(long) in CMSG_ALIGN. width := t.Arch().Width() - switch h.Type { - case linux.SCM_RIGHTS: - rightsSize := AlignDown(length, linux.SizeOfControlMessageRight) - numRights := rightsSize / linux.SizeOfControlMessageRight - - if len(fds)+numRights > linux.SCM_MAX_FD { - return transport.ControlMessages{}, syserror.EINVAL + switch h.Level { + case linux.SOL_SOCKET: + switch h.Type { + case linux.SCM_RIGHTS: + rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight) + numRights := rightsSize / linux.SizeOfControlMessageRight + + if len(fds)+numRights > linux.SCM_MAX_FD { + return socket.ControlMessages{}, syserror.EINVAL + } + + for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { + fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + } + + i += binary.AlignUp(length, width) + + case linux.SCM_CREDENTIALS: + if length < linux.SizeOfControlMessageCredentials { + return socket.ControlMessages{}, syserror.EINVAL + } + + var creds linux.ControlMessageCredentials + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) + scmCreds, err := NewSCMCredentials(t, creds) + if err != nil { + return socket.ControlMessages{}, err + } + cmsgs.Unix.Credentials = scmCreds + i += binary.AlignUp(length, width) + + default: + // Unknown message type. + return socket.ControlMessages{}, syserror.EINVAL } - - for j := i; j < i+rightsSize; j += linux.SizeOfControlMessageRight { - fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight]))) + case linux.SOL_IP: + switch h.Type { + case linux.IP_TOS: + if length < linux.SizeOfControlMessageTOS { + return socket.ControlMessages{}, syserror.EINVAL + } + cmsgs.IP.HasTOS = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS) + i += binary.AlignUp(length, width) + + case linux.IP_PKTINFO: + if length < linux.SizeOfControlMessageIPPacketInfo { + return socket.ControlMessages{}, syserror.EINVAL + } + + cmsgs.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + + cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo) + i += binary.AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - i += AlignUp(length, width) - - case linux.SCM_CREDENTIALS: - if length < linux.SizeOfControlMessageCredentials { - return transport.ControlMessages{}, syserror.EINVAL + case linux.SOL_IPV6: + switch h.Type { + case linux.IPV6_TCLASS: + if length < linux.SizeOfControlMessageTClass { + return socket.ControlMessages{}, syserror.EINVAL + } + cmsgs.IP.HasTClass = true + binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass) + i += binary.AlignUp(length, width) + + default: + return socket.ControlMessages{}, syserror.EINVAL } - - binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageCredentials], usermem.ByteOrder, &creds) - haveCreds = true - i += AlignUp(length, width) - default: - // Unknown message type. - return transport.ControlMessages{}, syserror.EINVAL + return socket.ControlMessages{}, syserror.EINVAL } } - var credentials SCMCredentials - if haveCreds { - var err error - if credentials, err = NewSCMCredentials(t, creds); err != nil { - return transport.ControlMessages{}, err - } - } else { - credentials = makeCreds(t, socketOrEndpoint) + if cmsgs.Unix.Credentials == nil { + cmsgs.Unix.Credentials = makeCreds(t, socketOrEndpoint) } - var rights SCMRights if len(fds) > 0 { - var err error - if rights, err = NewSCMRights(t, fds); err != nil { - return transport.ControlMessages{}, err + if kernel.VFS2Enabled { + rights, err := NewSCMRightsVFS2(t, fds) + if err != nil { + return socket.ControlMessages{}, err + } + cmsgs.Unix.Rights = rights + } else { + rights, err := NewSCMRights(t, fds) + if err != nil { + return socket.ControlMessages{}, err + } + cmsgs.Unix.Rights = rights } } - if credentials == nil && rights == nil { - return transport.ControlMessages{}, nil - } - - return transport.ControlMessages{Credentials: credentials, Rights: rights}, nil + return cmsgs, nil } func makeCreds(t *kernel.Task, socketOrEndpoint interface{}) SCMCredentials { @@ -432,6 +578,8 @@ func MakeCreds(t *kernel.Task) SCMCredentials { return &scmCredentials{t, tcred.EffectiveKUID, tcred.EffectiveKGID} } +// LINT.IfChange + // New creates default control messages if needed. func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transport.ControlMessages { return transport.ControlMessages{ @@ -439,3 +587,5 @@ func New(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRights) transpo Rights: rights, } } + +// LINT.ThenChange(./control_vfs2.go) diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go new file mode 100644 index 000000000..d9621968c --- /dev/null +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package control + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" +) + +// SCMRightsVFS2 represents a SCM_RIGHTS socket control message. +type SCMRightsVFS2 interface { + transport.RightsControlMessage + + // Files returns up to max RightsFiles. + // + // Returned files are consumed and ownership is transferred to the caller. + // Subsequent calls to Files will return the next files. + Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool) +} + +// RightsFiles represents a SCM_RIGHTS socket control message. A reference is +// maintained for each vfs.FileDescription and is release either when an FD is created or +// when the Release method is called. +type RightsFilesVFS2 []*vfs.FileDescription + +// NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message +// representation using local sentry FDs. +func NewSCMRightsVFS2(t *kernel.Task, fds []int32) (SCMRightsVFS2, error) { + files := make(RightsFilesVFS2, 0, len(fds)) + for _, fd := range fds { + file := t.GetFileVFS2(fd) + if file == nil { + files.Release(t) + return nil, syserror.EBADF + } + files = append(files, file) + } + return &files, nil +} + +// Files implements SCMRights.Files. +func (fs *RightsFilesVFS2) Files(ctx context.Context, max int) (RightsFilesVFS2, bool) { + n := max + var trunc bool + if l := len(*fs); n > l { + n = l + } else if n < l { + trunc = true + } + rf := (*fs)[:n] + *fs = (*fs)[n:] + return rf, trunc +} + +// Clone implements transport.RightsControlMessage.Clone. +func (fs *RightsFilesVFS2) Clone() transport.RightsControlMessage { + nfs := append(RightsFilesVFS2(nil), *fs...) + for _, nf := range nfs { + nf.IncRef() + } + return &nfs +} + +// Release implements transport.RightsControlMessage.Release. +func (fs *RightsFilesVFS2) Release(ctx context.Context) { + for _, f := range *fs { + f.DecRef(ctx) + } + *fs = nil +} + +// rightsFDsVFS2 gets up to the specified maximum number of FDs. +func rightsFDsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, max int) ([]int32, bool) { + files, trunc := rights.Files(t, max) + fds := make([]int32, 0, len(files)) + for i := 0; i < max && len(files) > 0; i++ { + fd, err := t.NewFDFromVFS2(0, files[0], kernel.FDFlags{ + CloseOnExec: cloexec, + }) + files[0].DecRef(t) + files = files[1:] + if err != nil { + t.Warningf("Error inserting FD: %v", err) + // This is what Linux does. + break + } + + fds = append(fds, int32(fd)) + } + return fds, trunc +} + +// PackRightsVFS2 packs as many FDs as will fit into the unused capacity of buf. +func PackRightsVFS2(t *kernel.Task, rights SCMRightsVFS2, cloexec bool, buf []byte, flags int) ([]byte, int) { + maxFDs := (cap(buf) - len(buf) - linux.SizeOfControlMessageHeader) / 4 + // Linux does not return any FDs if none fit. + if maxFDs <= 0 { + flags |= linux.MSG_CTRUNC + return buf, flags + } + fds, trunc := rightsFDsVFS2(t, rights, cloexec, maxFDs) + if trunc { + flags |= linux.MSG_CTRUNC + } + align := t.Arch().Width() + return putCmsg(buf, flags, linux.SCM_RIGHTS, align, fds) +} + +// NewVFS2 creates default control messages if needed. +func NewVFS2(t *kernel.Task, socketOrEndpoint interface{}, rights SCMRightsVFS2) transport.ControlMessages { + return transport.ControlMessages{ + Credentials: makeCreds(t, socketOrEndpoint), + Rights: rights, + } +} diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD index c1b20eaf8..8448ea401 100644 --- a/pkg/sentry/socket/hostinet/BUILD +++ b/pkg/sentry/socket/hostinet/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -10,31 +10,41 @@ go_library( "save_restore.go", "socket.go", "socket_unsafe.go", + "socket_vfs2.go", + "sockopt_impl.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/hostinet", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/fdnotifier", "//pkg/log", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", + "//pkg/sentry/fsimpl/sockfs", + "//pkg/sentry/hostfd", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", - "//pkg/sentry/usermem", + "//pkg/sentry/socket/control", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", + "//pkg/tcpip/stack", + "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go index 92beb1bcf..242e6bf76 100644 --- a/pkg/sentry/socket/hostinet/socket.go +++ b/pkg/sentry/socket/hostinet/socket.go @@ -18,21 +18,26 @@ import ( "fmt" "syscall" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fdnotifier" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const ( @@ -41,8 +46,14 @@ const ( // sizeofSockaddr is the size in bytes of the largest sockaddr type // supported by this package. sizeofSockaddr = syscall.SizeofSockaddrInet6 // sizeof(sockaddr_in6) > sizeof(sockaddr_in) + + // maxControlLen is the maximum size of a control message buffer used in a + // recvmsg or sendmsg syscall. + maxControlLen = 1024 ) +// LINT.IfChange + // socketOperations implements fs.FileOperations and socket.Socket for a socket // implemented using a host socket. type socketOperations struct { @@ -53,55 +64,74 @@ type socketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout family int // Read-only. stype linux.SockType // Read-only. protocol int // Read-only. - fd int // must be O_NONBLOCK queue waiter.Queue + + // fd is the host socket fd. It must have O_NONBLOCK, so that operations + // will return EWOULDBLOCK instead of blocking on the host. This allows us to + // handle blocking behavior independently in the sentry. + fd int } var _ = socket.Socket(&socketOperations{}) func newSocketFile(ctx context.Context, family int, stype linux.SockType, protocol int, fd int, nonblock bool) (*fs.File, *syserr.Error) { s := &socketOperations{ - family: family, - stype: stype, - protocol: protocol, - fd: fd, + socketOpsCommon: socketOpsCommon{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + }, } if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { return nil, syserr.FromError(err) } dirent := socket.NewDirent(ctx, socketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return fs.NewFile(ctx, dirent, fs.FileFlags{NonBlocking: nonblock, Read: true, Write: true, NonSeekable: true}, s), nil } // Release implements fs.FileOperations.Release. -func (s *socketOperations) Release() { +func (s *socketOpsCommon) Release(context.Context) { fdnotifier.RemoveFD(int32(s.fd)) syscall.Close(s.fd) } // Readiness implements waiter.Waitable.Readiness. -func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return fdnotifier.NonBlockingPoll(int32(s.fd), mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.queue.EventRegister(e, mask) fdnotifier.UpdateFD(int32(s.fd)) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *socketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.queue.EventUnregister(e) fdnotifier.UpdateFD(int32(s.fd)) } +// Ioctl implements fs.FileOperations.Ioctl. +func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return ioctl(ctx, s.fd, io, args) +} + // Read implements fs.FileOperations.Read. func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { n, err := dst.CopyOutFrom(ctx, safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -120,7 +150,7 @@ func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS } return uint64(n), nil } - return readv(s.fd, iovecsFromBlockSeq(dsts)) + return readv(s.fd, safemem.IovecsFromBlockSeq(dsts)) })) return int64(n), err } @@ -143,13 +173,13 @@ func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } return uint64(n), nil } - return writev(s.fd, iovecsFromBlockSeq(srcs)) + return writev(s.fd, safemem.IovecsFromBlockSeq(srcs)) })) return int64(n), err } // Connect implements socket.Socket.Connect. -func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -189,7 +219,7 @@ func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } // Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { var peerAddr linux.SockAddr var peerAddrBuf []byte var peerAddrlen uint32 @@ -203,7 +233,7 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } // Conservatively ignore all flags specified by the application and add - // SOCK_NONBLOCK since socketOperations requires it. + // SOCK_NONBLOCK since socketOpsCommon requires it. fd, syscallErr := accept4(s.fd, peerAddrPtr, peerAddrlenPtr, syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC) if blocking { var ch chan struct{} @@ -229,23 +259,41 @@ func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, return 0, peerAddr, peerAddrlen, syserr.FromError(syscallErr) } - f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) - if err != nil { - syscall.Close(fd) - return 0, nil, 0, err - } - defer f.DecRef() + var ( + kfd int32 + kerr error + ) + if kernel.VFS2Enabled { + f, err := newVFS2Socket(t, s.family, s.stype, s.protocol, fd, uint32(flags&syscall.SOCK_NONBLOCK)) + if err != nil { + syscall.Close(fd) + return 0, nil, 0, err + } + defer f.DecRef(t) - kfd, kerr := t.NewFDFrom(0, f, kernel.FDFlags{ - CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, - }) - t.Kernel().RecordSocket(f) + kfd, kerr = t.NewFDFromVFS2(0, f, kernel.FDFlags{ + CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, + }) + t.Kernel().RecordSocketVFS2(f) + } else { + f, err := newSocketFile(t, s.family, s.stype, s.protocol, fd, flags&syscall.SOCK_NONBLOCK != 0) + if err != nil { + syscall.Close(fd) + return 0, nil, 0, err + } + defer f.DecRef(t) + + kfd, kerr = t.NewFDFrom(0, f, kernel.FDFlags{ + CloseOnExec: flags&syscall.SOCK_CLOEXEC != 0, + }) + t.Kernel().RecordSocket(f) + } return kfd, peerAddr, peerAddrlen, syserr.FromError(kerr) } // Bind implements socket.Socket.Bind. -func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if len(sockaddr) > sizeofSockaddr { sockaddr = sockaddr[:sizeofSockaddr] } @@ -258,12 +306,12 @@ func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Listen implements socket.Socket.Listen. -func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return syserr.FromError(syscall.Listen(s.fd, backlog)) } // Shutdown implements socket.Socket.Shutdown. -func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { switch how { case syscall.SHUT_RD, syscall.SHUT_WR, syscall.SHUT_RDWR: return syserr.FromError(syscall.Shutdown(s.fd, how)) @@ -273,34 +321,40 @@ func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { if outLen < 0 { return nil, syserr.ErrInvalidArgument } - // Whitelist options and constrain option length. - var optlen int + // Only allow known and safe options. + optlen := getSockOptLen(t, level, name) switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: + switch name { + case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO: + optlen = sizeofInt32 + } + case linux.SOL_IPV6: switch name { - case syscall.IPV6_V6ONLY: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_SOCKET: + case linux.SOL_SOCKET: switch name { - case syscall.SO_ERROR, syscall.SO_KEEPALIVE, syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.SO_ERROR, linux.SO_KEEPALIVE, linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: optlen = sizeofInt32 - case syscall.SO_LINGER: + case linux.SO_LINGER: optlen = syscall.SizeofLinger } - case syscall.SOL_TCP: + case linux.SOL_TCP: switch name { - case syscall.TCP_NODELAY: + case linux.TCP_NODELAY: optlen = sizeofInt32 - case syscall.TCP_INFO: + case linux.TCP_INFO: optlen = int(linux.SizeOfTCPInfo) } } + if optlen == 0 { return nil, syserr.ErrProtocolNotAvailable // ENOPROTOOPT } @@ -312,30 +366,39 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt if err != nil { return nil, syserr.FromError(err) } - return opt, nil + optP := primitive.ByteSlice(opt) + return &optP, nil } // SetSockOpt implements socket.Socket.SetSockOpt. -func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { - // Whitelist options and constrain option length. - var optlen int +func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { + // Only allow known and safe options. + optlen := setSockOptLen(t, level, name) switch level { - case syscall.SOL_IPV6: + case linux.SOL_IP: switch name { - case syscall.IPV6_V6ONLY: + case linux.IP_TOS, linux.IP_RECVTOS: optlen = sizeofInt32 + case linux.IP_PKTINFO: + optlen = linux.SizeOfControlMessageIPPacketInfo } - case syscall.SOL_SOCKET: + case linux.SOL_IPV6: switch name { - case syscall.SO_SNDBUF, syscall.SO_RCVBUF, syscall.SO_REUSEADDR: + case linux.IPV6_TCLASS, linux.IPV6_RECVTCLASS, linux.IPV6_V6ONLY: optlen = sizeofInt32 } - case syscall.SOL_TCP: + case linux.SOL_SOCKET: switch name { - case syscall.TCP_NODELAY: + case linux.SO_SNDBUF, linux.SO_RCVBUF, linux.SO_REUSEADDR: + optlen = sizeofInt32 + } + case linux.SOL_TCP: + switch name { + case linux.TCP_NODELAY: optlen = sizeofInt32 } } + if optlen == 0 { // Pretend to accept socket options we don't understand. This seems // dangerous, but it's what netstack does... @@ -354,11 +417,11 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [ } // RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - // Whitelist flags. +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { + // Only allow known and safe flags. // // FIXME(jamieliu): We can't support MSG_ERRQUEUE because it uses ancillary - // messages that netstack/tcpip/transport/unix doesn't understand. Kill the + // messages that gvisor/pkg/tcpip/transport/unix doesn't understand. Kill the // Socket interface's dependence on netstack. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_PEEK|syscall.MSG_TRUNC) != 0 { return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrInvalidArgument @@ -370,6 +433,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags senderAddrBuf = make([]byte, sizeofSockaddr) } + var controlBuf []byte var msgFlags int recvmsgToBlocks := safemem.ReaderFunc(func(dsts safemem.BlockSeq) (uint64, error) { @@ -384,12 +448,7 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // We always do a non-blocking recv*(). sysflags := flags | syscall.MSG_DONTWAIT - if dsts.NumBlocks() == 1 { - // Skip allocating []syscall.Iovec. - return recvfrom(s.fd, dsts.Head().ToSlice(), sysflags, &senderAddrBuf) - } - - iovs := iovecsFromBlockSeq(dsts) + iovs := safemem.IovecsFromBlockSeq(dsts) msg := syscall.Msghdr{ Iov: &iovs[0], Iovlen: uint64(len(iovs)), @@ -398,12 +457,21 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags msg.Name = &senderAddrBuf[0] msg.Namelen = uint32(len(senderAddrBuf)) } + if controlLen > 0 { + if controlLen > maxControlLen { + controlLen = maxControlLen + } + controlBuf = make([]byte, controlLen) + msg.Control = &controlBuf[0] + msg.Controllen = controlLen + } n, err := recvmsg(s.fd, &msg, sysflags) if err != nil { return 0, err } senderAddrBuf = senderAddrBuf[:msg.Namelen] msgFlags = int(msg.Flags) + controlLen = uint64(msg.Controllen) return n, nil }) @@ -429,36 +497,75 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags n, err = dst.CopyOutFrom(t, recvmsgToBlocks) } } - - // We don't allow control messages. - msgFlags &^= linux.MSG_CTRUNC + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } if senderRequested { senderAddr = socket.UnmarshalSockAddr(s.family, senderAddrBuf) } - return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), socket.ControlMessages{}, syserr.FromError(err) + + unixControlMessages, err := unix.ParseSocketControlMessage(controlBuf[:controlLen]) + if err != nil { + return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) + } + + controlMessages := socket.ControlMessages{} + for _, unixCmsg := range unixControlMessages { + switch unixCmsg.Header.Level { + case syscall.SOL_IP: + switch unixCmsg.Header.Type { + case syscall.IP_TOS: + controlMessages.IP.HasTOS = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS) + + case syscall.IP_PKTINFO: + controlMessages.IP.HasIPPacketInfo = true + var packetInfo linux.ControlMessageIPPacketInfo + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo) + controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo) + } + + case syscall.SOL_IPV6: + switch unixCmsg.Header.Type { + case syscall.IPV6_TCLASS: + controlMessages.IP.HasTClass = true + binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTClass], usermem.ByteOrder, &controlMessages.IP.TClass) + } + } + } + + return int(n), msgFlags, senderAddr, uint32(len(senderAddrBuf)), controlMessages, nil } // SendMsg implements socket.Socket.SendMsg. -func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { - // Whitelist flags. +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { + // Only allow known and safe flags. if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { return 0, syserr.ErrInvalidArgument } + space := uint64(control.CmsgsSpace(t, controlMessages)) + if space > maxControlLen { + space = maxControlLen + } + controlBuf := make([]byte, 0, space) + // PackControlMessages will append up to space bytes to controlBuf. + controlBuf = control.PackControlMessages(t, controlMessages, controlBuf) + sendmsgFromBlocks := safemem.WriterFunc(func(srcs safemem.BlockSeq) (uint64, error) { // Refuse to do anything if any part of src.Addrs was unusable. if uint64(src.NumBytes()) != srcs.NumBytes() { return 0, nil } - if srcs.IsEmpty() { + if srcs.IsEmpty() && len(controlBuf) == 0 { return 0, nil } // We always do a non-blocking send*(). sysflags := flags | syscall.MSG_DONTWAIT - if srcs.NumBlocks() == 1 { + if srcs.NumBlocks() == 1 && len(controlBuf) == 0 { // Skip allocating []syscall.Iovec. src := srcs.Head() n, _, errno := syscall.Syscall6(syscall.SYS_SENDTO, uintptr(s.fd), src.Addr(), uintptr(src.Len()), uintptr(sysflags), uintptr(firstBytePtr(to)), uintptr(len(to))) @@ -468,7 +575,7 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return uint64(n), nil } - iovs := iovecsFromBlockSeq(srcs) + iovs := safemem.IovecsFromBlockSeq(srcs) msg := syscall.Msghdr{ Iov: &iovs[0], Iovlen: uint64(len(iovs)), @@ -477,6 +584,10 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] msg.Name = &to[0] msg.Namelen = uint32(len(to)) } + if len(controlBuf) != 0 { + msg.Control = &controlBuf[0] + msg.Controllen = uint64(len(controlBuf)) + } return sendmsg(s.fd, &msg, sysflags) }) @@ -509,21 +620,6 @@ func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] return int(n), syserr.FromError(err) } -func iovecsFromBlockSeq(bs safemem.BlockSeq) []syscall.Iovec { - iovs := make([]syscall.Iovec, 0, bs.NumBlocks()) - for ; !bs.IsEmpty(); bs = bs.Tail() { - b := bs.Head() - iovs = append(iovs, syscall.Iovec{ - Base: &b.ToSlice()[0], - Len: uint64(b.Len()), - }) - // We don't need to care about b.NeedSafecopy(), because the host - // kernel will handle such address ranges just fine (by returning - // EFAULT). - } - return iovs -} - func translateIOSyscallError(err error) error { if err == syscall.EAGAIN || err == syscall.EWOULDBLOCK { return syserror.ErrWouldBlock @@ -532,7 +628,7 @@ func translateIOSyscallError(err error) error { } // State implements socket.Socket.State. -func (s *socketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { info := linux.TCPInfo{} buf, err := getsockopt(s.fd, syscall.SOL_TCP, syscall.TCP_INFO, linux.SizeOfTCPInfo) if err != nil { @@ -554,7 +650,7 @@ func (s *socketOperations) State() uint32 { } // Type implements socket.Socket.Type. -func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return s.family, s.stype, s.protocol } @@ -610,8 +706,11 @@ func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int return nil, nil, nil } +// LINT.ThenChange(./socket_vfs2.go) + func init() { for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { socket.RegisterProvider(family, &socketProvider{family}) + socket.RegisterProviderVFS2(family, &socketProviderVFS2{family}) } } diff --git a/pkg/sentry/socket/hostinet/socket_unsafe.go b/pkg/sentry/socket/hostinet/socket_unsafe.go index e69ec38c2..3f420c2ec 100644 --- a/pkg/sentry/socket/hostinet/socket_unsafe.go +++ b/pkg/sentry/socket/hostinet/socket_unsafe.go @@ -19,14 +19,13 @@ import ( "unsafe" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) func firstBytePtr(bs []byte) unsafe.Pointer { @@ -54,12 +53,11 @@ func writev(fd int, srcs []syscall.Iovec) (uint64, error) { return uint64(n), nil } -// Ioctl implements fs.FileOperations.Ioctl. -func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { +func ioctl(ctx context.Context, fd int, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { switch cmd := uintptr(args[1].Int()); cmd { case syscall.TIOCINQ, syscall.TIOCOUTQ: var val int32 - if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(s.fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 { + if _, _, errno := syscall.Syscall(syscall.SYS_IOCTL, uintptr(fd), cmd, uintptr(unsafe.Pointer(&val))); errno != 0 { return 0, translateIOSyscallError(errno) } var buf [4]byte @@ -93,7 +91,7 @@ func getsockopt(fd int, level, name int, optlen int) ([]byte, error) { } // GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETSOCKNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) @@ -104,7 +102,7 @@ func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, } // GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr := make([]byte, sizeofSockaddr) addrlen := uint32(len(addr)) _, _, errno := syscall.Syscall(syscall.SYS_GETPEERNAME, uintptr(s.fd), uintptr(unsafe.Pointer(&addr[0])), uintptr(unsafe.Pointer(&addrlen))) diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go new file mode 100644 index 000000000..8a1d52ebf --- /dev/null +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -0,0 +1,203 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostinet + +import ( + "syscall" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +type socketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + // We store metadata for hostinet sockets internally. Technically, we should + // access metadata (e.g. through stat, chmod) on the host for correctness, + // but this is not very useful for inet socket fds, which do not belong to a + // concrete file anyway. + vfs.DentryMetadataFileDescriptionImpl + + socketOpsCommon +} + +var _ = socket.SocketVFS2(&socketVFS2{}) + +func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { + mnt := t.Kernel().SocketMount() + d := sockfs.NewDentry(t.Credentials(), mnt) + + s := &socketVFS2{ + socketOpsCommon: socketOpsCommon{ + family: family, + stype: stype, + protocol: protocol, + fd: fd, + }, + } + s.LockFD.Init(&vfs.FileLocks{}) + if err := fdnotifier.AddFD(int32(fd), &s.queue); err != nil { + return nil, syserr.FromError(err) + } + vfsfd := &s.vfsfd + if err := vfsfd.Init(s, linux.O_RDWR|(flags&linux.O_NONBLOCK), mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + fdnotifier.RemoveFD(int32(s.fd)) + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *socketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *socketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *socketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return ioctl(ctx, s.fd, uio, args) +} + +// Allocate implements vfs.FileDescriptionImpl.Allocate. +func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error { + return syserror.ENODEV +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + reader := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags) + n, err := dst.CopyOutFrom(ctx, reader) + hostfd.PutReadWriterAt(reader) + return int64(n), err +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *socketVFS2) PWrite(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *socketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + writer := hostfd.GetReadWriterAt(int32(s.fd), -1, opts.Flags) + n, err := src.CopyInTo(ctx, writer) + hostfd.PutReadWriterAt(writer) + return int64(n), err +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *socketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *socketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + +type socketProviderVFS2 struct { + family int +} + +// Socket implements socket.ProviderVFS2.Socket. +func (p *socketProviderVFS2) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Check that we are using the host network stack. + stack := t.NetworkContext() + if stack == nil { + return nil, nil + } + if _, ok := stack.(*Stack); !ok { + return nil, nil + } + + // Only accept TCP and UDP. + stype := stypeflags & linux.SOCK_TYPE_MASK + switch stype { + case syscall.SOCK_STREAM: + switch protocol { + case 0, syscall.IPPROTO_TCP: + // ok + default: + return nil, nil + } + case syscall.SOCK_DGRAM: + switch protocol { + case 0, syscall.IPPROTO_UDP: + // ok + default: + return nil, nil + } + default: + return nil, nil + } + + // Conservatively ignore all flags specified by the application and add + // SOCK_NONBLOCK since socketOperations requires it. Pass a protocol of 0 + // to simplify the syscall filters, since 0 and IPPROTO_* are equivalent. + fd, err := syscall.Socket(p.family, int(stype)|syscall.SOCK_NONBLOCK|syscall.SOCK_CLOEXEC, 0) + if err != nil { + return nil, syserr.FromError(err) + } + return newVFS2Socket(t, p.family, stype, protocol, fd, uint32(stypeflags&syscall.SOCK_NONBLOCK)) +} + +// Pair implements socket.Provider.Pair. +func (p *socketProviderVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Not supported by AF_INET/AF_INET6. + return nil, nil, nil +} diff --git a/pkg/sentry/socket/rpcinet/device.go b/pkg/sentry/socket/hostinet/sockopt_impl.go index 8cfd5f6e5..8a783712e 100644 --- a/pkg/sentry/socket/rpcinet/device.go +++ b/pkg/sentry/socket/hostinet/sockopt_impl.go @@ -1,4 +1,4 @@ -// Copyright 2018 The gVisor Authors. +// Copyright 2020 The gVisor Authors. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,8 +12,16 @@ // See the License for the specific language governing permissions and // limitations under the License. -package rpcinet +package hostinet -import "gvisor.dev/gvisor/pkg/sentry/device" +import ( + "gvisor.dev/gvisor/pkg/sentry/kernel" +) -var socketDevice = device.NewAnonDevice() +func getSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. +} + +func setSockOptLen(t *kernel.Task, level, name int) int { + return 0 // No custom options. +} diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 4b460d30e..3d3fabb30 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -25,15 +25,16 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) var defaultRecvBufSize = inet.TCPBufferSize{ @@ -55,6 +56,7 @@ type Stack struct { interfaceAddrs map[int32][]inet.InterfaceAddr routes []inet.Route supportsIPv6 bool + tcpRecovery inet.TCPLossRecovery tcpRecvBufSize inet.TCPBufferSize tcpSendBufSize inet.TCPBufferSize tcpSACKEnabled bool @@ -128,6 +130,13 @@ func (s *Stack) Configure() error { log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false") } + s.ipv4Forwarding = false + if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil { + s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0" + } else { + log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false") + } + return nil } @@ -321,6 +330,11 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return addrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + return syserror.EACCES +} + // SupportsIPv6 implements inet.Stack.SupportsIPv6. func (s *Stack) SupportsIPv6() bool { return s.supportsIPv6 @@ -356,6 +370,16 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserror.EACCES } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + return s.tcpRecovery, nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserror.EACCES +} + // getLine reads one line from proc file, with specified prefix. // The last argument, withHeader, specifies if it contains line header. func getLine(f *os.File, prefix string, withHeader bool) string { @@ -455,6 +479,15 @@ func (s *Stack) RouteTable() []inet.Route { // Resume implements inet.Stack.Resume. func (s *Stack) Resume() {} +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} + // Forwarding implements inet.Stack.Forwarding. func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { switch protocol { diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 5eb06bbf4..721094bbf 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -1,24 +1,29 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "netfilter", srcs = [ + "extensions.go", "netfilter.go", + "owner_matcher.go", + "targets.go", + "tcp_matcher.go", + "udp_matcher.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netfilter", # This target depends on netstack and should only be used by epsocket, # which is allowed to depend on netstack. visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/log", "//pkg/sentry/kernel", - "//pkg/sentry/usermem", "//pkg/syserr", "//pkg/tcpip", - "//pkg/tcpip/iptables", + "//pkg/tcpip/header", "//pkg/tcpip/stack", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go new file mode 100644 index 000000000..0336a32d8 --- /dev/null +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -0,0 +1,95 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +// TODO(gvisor.dev/issue/170): The following per-matcher params should be +// supported: +// - Table name +// - Match size +// - User size +// - Hooks +// - Proto +// - Family + +// matchMaker knows how to (un)marshal the matcher named name(). +type matchMaker interface { + // name is the matcher name as stored in the xt_entry_match struct. + name() string + + // marshal converts from an stack.Matcher to an ABI struct. + marshal(matcher stack.Matcher) []byte + + // unmarshal converts from the ABI matcher struct to an + // stack.Matcher. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) +} + +// matchMakers maps the name of supported matchers to the matchMaker that +// marshals and unmarshals it. It is immutable after package initialization. +var matchMakers = map[string]matchMaker{} + +// registermatchMaker should be called by match extensions to register them +// with the netfilter package. +func registerMatchMaker(mm matchMaker) { + if _, ok := matchMakers[mm.name()]; ok { + panic(fmt.Sprintf("Multiple matches registered with name %q.", mm.name())) + } + matchMakers[mm.name()] = mm +} + +func marshalMatcher(matcher stack.Matcher) []byte { + matchMaker, ok := matchMakers[matcher.Name()] + if !ok { + panic(fmt.Sprintf("Unknown matcher of type %T.", matcher)) + } + return matchMaker.marshal(matcher) +} + +// marshalEntryMatch creates a marshalled XTEntryMatch with the given name and +// data appended at the end. +func marshalEntryMatch(name string, data []byte) []byte { + nflog("marshaling matcher %q", name) + + // We have to pad this struct size to a multiple of 8 bytes. + size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8) + matcher := linux.KernelXTEntryMatch{ + XTEntryMatch: linux.XTEntryMatch{ + MatchSize: uint16(size), + }, + Data: data, + } + copy(matcher.Name[:], name) + + buf := make([]byte, 0, size) + buf = binary.Marshal(buf, usermem.ByteOrder, matcher) + return append(buf, make([]byte, size-len(buf))...) +} + +func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) { + matchMaker, ok := matchMakers[match.Name.String()] + if !ok { + return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String()) + } + return matchMaker.unmarshal(buf, filter) +} diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 9f87c32f1..e91b0624c 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -17,270 +17,506 @@ package netfilter import ( + "bytes" + "errors" "fmt" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel" - "gvisor.dev/gvisor/pkg/sentry/usermem" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/iptables" + "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" ) -// errorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const errorTargetName = "ERROR" - -// metadata is opaque to netstack. It holds data that we need to translate -// between Linux's and netstack's iptables representations. -type metadata struct { - HookEntry [linux.NF_INET_NUMHOOKS]uint32 - Underflow [linux.NF_INET_NUMHOOKS]uint32 - NumEntries uint32 - Size uint32 +// enableLogging controls whether to log the (de)serialization of netfilter +// structs between userspace and netstack. These logs are useful when +// developing iptables, but can pollute sentry logs otherwise. +const enableLogging = false + +// emptyFilter is for comparison with a rule's filters to determine whether it +// is also empty. It is immutable. +var emptyFilter = stack.IPHeaderFilter{ + Dst: "\x00\x00\x00\x00", + DstMask: "\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00", +} + +// nflog logs messages related to the writing and reading of iptables. +func nflog(format string, args ...interface{}) { + if enableLogging && log.IsLogging(log.Debug) { + log.Debugf("netfilter: "+format, args...) + } } // GetInfo returns information about iptables. -func GetInfo(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { +func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo - if _, err := t.CopyIn(outPtr, &info); err != nil { + if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(ep, info.TableName()) + _, info, err := convertNetstackToBinary(stack, info.Name) if err != nil { - return linux.IPTGetinfo{}, err + nflog("couldn't convert iptables: %v", err) + return linux.IPTGetinfo{}, syserr.ErrInvalidArgument } - // Get the hooks that apply to this table. - info.ValidHooks = table.ValidHooks() - - // Grab the metadata struct, which is used to store information (e.g. - // the number of entries) that applies to the user's encoding of - // iptables, but not netstack's. - metadata := table.Metadata().(metadata) - - // Set values from metadata. - info.HookEntry = metadata.HookEntry - info.Underflow = metadata.Underflow - info.NumEntries = metadata.NumEntries - info.Size = metadata.Size - + nflog("returning info: %+v", info) return info, nil } // GetEntries returns netstack's iptables rules encoded for the iptables tool. -func GetEntries(t *kernel.Task, ep tcpip.Endpoint, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { +func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { // Read in the struct and table name. var userEntries linux.IPTGetEntries - if _, err := t.CopyIn(outPtr, &userEntries); err != nil { + if _, err := userEntries.CopyIn(t, outPtr); err != nil { + nflog("couldn't copy in entries %q", userEntries.Name) return linux.KernelIPTGetEntries{}, syserr.FromError(err) } - // Find the appropriate table. - table, err := findTable(ep, userEntries.TableName()) - if err != nil { - return linux.KernelIPTGetEntries{}, err - } - // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, _, err := convertNetstackToBinary(userEntries.TableName(), table) + entries, _, err := convertNetstackToBinary(stack, userEntries.Name) if err != nil { - return linux.KernelIPTGetEntries{}, err + nflog("couldn't read entries: %v", err) + return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } if binary.Size(entries) > uintptr(outLen) { + nflog("insufficient GetEntries output size: %d", uintptr(outLen)) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument } return entries, nil } -func findTable(ep tcpip.Endpoint, tableName string) (iptables.Table, *syserr.Error) { - ipt, err := ep.IPTables() - if err != nil { - return iptables.Table{}, syserr.FromError(err) - } - table, ok := ipt.Tables[tableName] +// convertNetstackToBinary converts the iptables as stored in netstack to the +// format expected by the iptables tool. Linux stores each table as a binary +// blob that can only be traversed by parsing a bit, reading some offsets, +// jumping to those offsets, parsing again, etc. +func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + table, ok := stack.IPTables().GetTable(tablename.String()) if !ok { - return iptables.Table{}, syserr.ErrInvalidArgument + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + + var entries linux.KernelIPTGetEntries + var info linux.IPTGetinfo + info.ValidHooks = table.ValidHooks() + + // The table name has to fit in the struct. + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - return table, nil + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], tablename[:]) + + for ruleIdx, rule := range table.Rules { + nflog("convert to binary: current offset: %d", entries.Size) + + // Is this a chain entry point? + for hook, hookRuleIdx := range table.BuiltinChains { + if hookRuleIdx == ruleIdx { + nflog("convert to binary: found hook %d at offset %d", hook, entries.Size) + info.HookEntry[hook] = entries.Size + } + } + // Is this a chain underflow point? + for underflow, underflowRuleIdx := range table.Underflows { + if underflowRuleIdx == ruleIdx { + nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size) + info.Underflow[underflow] = entries.Size + } + } + + // Each rule corresponds to an entry. + entry := linux.KernelIPTEntry{ + Entry: linux.IPTEntry{ + IP: linux.IPTIP{ + Protocol: uint16(rule.Filter.Protocol), + }, + NextOffset: linux.SizeOfIPTEntry, + TargetOffset: linux.SizeOfIPTEntry, + }, + } + copy(entry.Entry.IP.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IP.Src[:], rule.Filter.Src) + copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP + } + if rule.Filter.SrcInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP + } + if rule.Filter.OutputInterfaceInvert { + entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT + } + + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) + nflog("convert to binary: matcher serialized as: %v", serialized) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) + } + + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + + nflog("convert to binary: adding entry: %+v", entry) + + entries.Size += uint32(entry.Entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + info.NumEntries++ + } + + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + info.Size = entries.Size + return entries, info, nil } -// FillDefaultIPTables sets stack's IPTables to the default tables and -// populates them with metadata. -func FillDefaultIPTables(stack *stack.Stack) { - ipt := iptables.DefaultTables() +// SetEntries sets iptables rules for a single table. See +// net/ipv4/netfilter/ip_tables.c:translate_table for reference. +func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { + // Get the basic rules data (struct ipt_replace). + if len(optVal) < linux.SizeOfIPTReplace { + nflog("optVal has insufficient size for replace %d", len(optVal)) + return syserr.ErrInvalidArgument + } + var replace linux.IPTReplace + replaceBuf := optVal[:linux.SizeOfIPTReplace] + optVal = optVal[linux.SizeOfIPTReplace:] + binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace) + + // TODO(gvisor.dev/issue/170): Support other tables. + var table stack.Table + switch replace.Name.String() { + case stack.FilterTable: + table = stack.EmptyFilterTable() + case stack.NATTable: + table = stack.EmptyNATTable() + default: + nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) + return syserr.ErrInvalidArgument + } - // In order to fill in the metadata, we have to translate ipt from its - // netstack format to Linux's giant-binary-blob format. - for name, table := range ipt.Tables { - _, metadata, err := convertNetstackToBinary(name, table) + nflog("set entries: setting entries in table %q", replace.Name.String()) + + // Convert input into a list of rules and their offsets. + var offset uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + nflog("set entries: processing entry at offset %d", offset) + + // Get the struct ipt_entry. + if len(optVal) < linux.SizeOfIPTEntry { + nflog("optVal has insufficient size for entry %d", len(optVal)) + return syserr.ErrInvalidArgument + } + var entry linux.IPTEntry + buf := optVal[:linux.SizeOfIPTEntry] + binary.Unmarshal(buf, usermem.ByteOrder, &entry) + initialOptValLen := len(optVal) + optVal = optVal[linux.SizeOfIPTEntry:] + + if entry.TargetOffset < linux.SizeOfIPTEntry { + nflog("entry has too-small target offset %d", entry.TargetOffset) + return syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): We should support more IPTIP + // filtering fields. + filter, err := filterFromIPTIP(entry.IP) if err != nil { - panic(fmt.Errorf("Unable to set default IP tables: %v", err)) + nflog("bad iptip: %v", err) + return syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Matchers and targets can specify + // that they only work for certain protocols, hooks, tables. + // Get matchers. + matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry + if len(optVal) < int(matchersSize) { + nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + return syserr.ErrInvalidArgument + } + matchers, err := parseMatchers(filter, optVal[:matchersSize]) + if err != nil { + nflog("failed to parse matchers: %v", err) + return syserr.ErrInvalidArgument + } + optVal = optVal[matchersSize:] + + // Get the target of the rule. + targetSize := entry.NextOffset - entry.TargetOffset + if len(optVal) < int(targetSize) { + nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + return syserr.ErrInvalidArgument + } + target, err := parseTarget(filter, optVal[:targetSize]) + if err != nil { + nflog("failed to parse target: %v", err) + return syserr.ErrInvalidArgument + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, stack.Rule{ + Filter: filter, + Target: target, + Matchers: matchers, + }) + offsets[offset] = int(entryIdx) + offset += uint32(entry.NextOffset) + + if initialOptValLen-len(optVal) != int(entry.NextOffset) { + nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return syserr.ErrInvalidArgument } - table.SetMetadata(metadata) - ipt.Tables[name] = table } - stack.SetIPTables(ipt) -} + // Go through the list of supported hooks for this table and, for each + // one, set the rule it corresponds to. + for hook, _ := range replace.HookEntry { + if table.ValidHooks()&(1<<hook) != 0 { + hk := hookFromLinux(hook) + table.BuiltinChains[hk] = stack.HookUnset + table.Underflows[hk] = stack.HookUnset + for offset, ruleIdx := range offsets { + if offset == replace.HookEntry[hook] { + table.BuiltinChains[hk] = ruleIdx + } + if offset == replace.Underflow[hook] { + if !validUnderflow(table.Rules[ruleIdx]) { + nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) + return syserr.ErrInvalidArgument + } + table.Underflows[hk] = ruleIdx + } + } + if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset { + nflog("hook %v is unset.", hk) + return syserr.ErrInvalidArgument + } + if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset { + nflog("underflow %v is unset.", hk) + return syserr.ErrInvalidArgument + } + } + } -// convertNetstackToBinary converts the iptables as stored in netstack to the -// format expected by the iptables tool. Linux stores each table as a binary -// blob that can only be traversed by parsing a bit, reading some offsets, -// jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(name string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, *syserr.Error) { - // Return values. - var entries linux.KernelIPTGetEntries - var meta metadata + // Add the user chains. + for ruleIdx, rule := range table.Rules { + if _, ok := rule.Target.(stack.UserChainTarget); !ok { + continue + } - // The table name has to fit in the struct. - if linux.XT_TABLE_MAXNAMELEN < len(name) { - return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument + // We found a user chain. Before inserting it into the table, + // check that: + // - There's some other rule after it. + // - There are no matchers. + if ruleIdx == len(table.Rules)-1 { + nflog("user chain must have a rule or default policy") + return syserr.ErrInvalidArgument + } + if len(table.Rules[ruleIdx].Matchers) != 0 { + nflog("user chain's first node must have no matchers") + return syserr.ErrInvalidArgument + } } - copy(entries.Name[:], name) - // Deal with the built in chains first (INPUT, OUTPUT, etc.). Each of - // these chains ends with an unconditional policy entry. - for hook := iptables.Prerouting; hook < iptables.NumHooks; hook++ { - chain, ok := table.BuiltinChains[hook] + // Set each jump to point to the appropriate rule. Right now they hold byte + // offsets. + for ruleIdx, rule := range table.Rules { + jump, ok := rule.Target.(JumpTarget) if !ok { - // This table doesn't support this hook. continue } - // Sanity check. - if len(chain.Rules) < 1 { - return linux.KernelIPTGetEntries{}, metadata{}, syserr.ErrInvalidArgument + // Find the rule corresponding to the jump rule offset. + jumpTo, ok := offsets[jump.Offset] + if !ok { + nflog("failed to find a rule to jump to") + return syserr.ErrInvalidArgument } + jump.RuleNum = jumpTo + rule.Target = jump + table.Rules[ruleIdx] = rule + } - for ruleIdx, rule := range chain.Rules { - // If this is the first rule of a builtin chain, set - // the metadata hook entry point. - if ruleIdx == 0 { - meta.HookEntry[hook] = entries.Size + // TODO(gvisor.dev/issue/170): Support other chains. + // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now, + // make sure all other chains point to ACCEPT rules. + for hook, ruleIdx := range table.BuiltinChains { + if hook := stack.Hook(hook); hook == stack.Forward || hook == stack.Postrouting { + if ruleIdx == stack.HookUnset { + continue } - - // Each rule corresponds to an entry. - entry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ - NextOffset: linux.SizeOfIPTEntry, - TargetOffset: linux.SizeOfIPTEntry, - }, + if !isUnconditionalAccept(table.Rules[ruleIdx]) { + nflog("hook %d is unsupported.", hook) + return syserr.ErrInvalidArgument } + } + } - for _, matcher := range rule.Matchers { - // Serialize the matcher and add it to the - // entry. - serialized := marshalMatcher(matcher) - entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) - entry.TargetOffset += uint16(len(serialized)) - } + // TODO(gvisor.dev/issue/170): Check the following conditions: + // - There are no loops. + // - There are no chains without an unconditional final rule. + // - There are no chains without an unconditional underflow rule. - // Serialize and append the target. - serialized := marshalTarget(rule.Target) - entry.Elems = append(entry.Elems, serialized...) - entry.NextOffset += uint16(len(serialized)) + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table)) +} - // The underflow rule is the last rule in the chain, - // and is an unconditional rule (i.e. it matches any - // packet). This is enforced when saving iptables. - if ruleIdx == len(chain.Rules)-1 { - meta.Underflow[hook] = entries.Size - } +// parseMatchers parses 0 or more matchers from optVal. optVal should contain +// only the matchers. +func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) { + nflog("set entries: parsing matchers of size %d", len(optVal)) + var matchers []stack.Matcher + for len(optVal) > 0 { + nflog("set entries: optVal has len %d", len(optVal)) + + // Get the XTEntryMatch. + if len(optVal) < linux.SizeOfXTEntryMatch { + return nil, fmt.Errorf("optVal has insufficient size for entry match: %d", len(optVal)) + } + var match linux.XTEntryMatch + buf := optVal[:linux.SizeOfXTEntryMatch] + binary.Unmarshal(buf, usermem.ByteOrder, &match) + nflog("set entries: parsed entry match %q: %+v", match.Name.String(), match) + + // Check some invariants. + if match.MatchSize < linux.SizeOfXTEntryMatch { - entries.Size += uint32(entry.NextOffset) - entries.Entrytable = append(entries.Entrytable, entry) - meta.NumEntries++ + return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch) + } + if len(optVal) < int(match.MatchSize) { + return nil, fmt.Errorf("optVal has insufficient size for match: %d", len(optVal)) } + // Parse the specific matcher. + matcher, err := unmarshalMatcher(match, filter, optVal[linux.SizeOfXTEntryMatch:match.MatchSize]) + if err != nil { + return nil, fmt.Errorf("failed to create matcher: %v", err) + } + matchers = append(matchers, matcher) + + // TODO(gvisor.dev/issue/170): Check the revision field. + optVal = optVal[match.MatchSize:] } - // TODO(gvisor.dev/issue/170): Deal with the user chains here. Each of - // these starts with an error node holding the chain's name and ends - // with an unconditional return. - - // Lastly, each table ends with an unconditional error target rule as - // its final entry. - errorEntry := linux.KernelIPTEntry{ - IPTEntry: linux.IPTEntry{ - NextOffset: linux.SizeOfIPTEntry, - TargetOffset: linux.SizeOfIPTEntry, - }, + if len(optVal) != 0 { + return nil, errors.New("optVal should be exhausted after parsing matchers") } - var errorTarget linux.XTErrorTarget - errorTarget.Target.TargetSize = linux.SizeOfXTErrorTarget - copy(errorTarget.ErrorName[:], errorTargetName) - copy(errorTarget.Target.Name[:], errorTargetName) - - // Serialize and add it to the list of entries. - errorTargetBuf := make([]byte, 0, linux.SizeOfXTErrorTarget) - serializedErrorTarget := binary.Marshal(errorTargetBuf, usermem.ByteOrder, errorTarget) - errorEntry.Elems = append(errorEntry.Elems, serializedErrorTarget...) - errorEntry.NextOffset += uint16(len(serializedErrorTarget)) - - entries.Size += uint32(errorEntry.NextOffset) - entries.Entrytable = append(entries.Entrytable, errorEntry) - meta.NumEntries++ - meta.Size = entries.Size - - return entries, meta, nil + + return matchers, nil } -func marshalMatcher(matcher iptables.Matcher) []byte { - switch matcher.(type) { - default: - // TODO(gvisor.dev/issue/170): We don't support any matchers yet, so - // any call to marshalMatcher will panic. - panic(fmt.Errorf("unknown matcher of type %T", matcher)) +func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { + if containsUnsupportedFields(iptip) { + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + } + if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + } + if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + + return stack.IPHeaderFilter{ + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0, + }, nil } -func marshalTarget(target iptables.Target) []byte { - switch target.(type) { - case iptables.UnconditionalAcceptTarget: - return marshalUnconditionalAcceptTarget() +func containsUnsupportedFields(iptip linux.IPTIP) bool { + // The following features are supported: + // - Protocol + // - Dst and DstMask + // - Src and SrcMask + // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. + var emptyInterface = [linux.IFNAMSIZ]byte{} + // Disable any supported inverse flags. + inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || + iptip.InputInterfaceMask != emptyInterface || + iptip.Flags != 0 || + iptip.InverseFlags&^inverseMask != 0 +} + +func validUnderflow(rule stack.Rule) bool { + if len(rule.Matchers) != 0 { + return false + } + if rule.Filter != emptyFilter { + return false + } + switch rule.Target.(type) { + case stack.AcceptTarget, stack.DropTarget: + return true default: - panic(fmt.Errorf("unknown target of type %T", target)) + return false } } -func marshalUnconditionalAcceptTarget() []byte { - // The target's name will be the empty string. - target := linux.XTStandardTarget{ - Target: linux.XTEntryTarget{ - TargetSize: linux.SizeOfXTStandardTarget, - }, - Verdict: translateStandardVerdict(iptables.Accept), +func isUnconditionalAccept(rule stack.Rule) bool { + if !validUnderflow(rule) { + return false } - - ret := make([]byte, 0, linux.SizeOfXTStandardTarget) - return binary.Marshal(ret, usermem.ByteOrder, target) + _, ok := rule.Target.(stack.AcceptTarget) + return ok } -// translateStandardVerdict translates verdicts the same way as the iptables -// tool. -func translateStandardVerdict(verdict iptables.Verdict) int32 { - switch verdict { - case iptables.Accept: - return -linux.NF_ACCEPT - 1 - case iptables.Drop: - return -linux.NF_DROP - 1 - case iptables.Queue: - return -linux.NF_QUEUE - 1 - case iptables.Return: - return linux.NF_RETURN - case iptables.Jump: - // TODO(gvisor.dev/issue/170): Support Jump. - panic("Jump isn't supported yet") - default: - panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) +func hookFromLinux(hook int) stack.Hook { + switch hook { + case linux.NF_INET_PRE_ROUTING: + return stack.Prerouting + case linux.NF_INET_LOCAL_IN: + return stack.Input + case linux.NF_INET_FORWARD: + return stack.Forward + case linux.NF_INET_LOCAL_OUT: + return stack.Output + case linux.NF_INET_POST_ROUTING: + return stack.Postrouting } + panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook)) } diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go new file mode 100644 index 000000000..1b4e0ad79 --- /dev/null +++ b/pkg/sentry/socket/netfilter/owner_matcher.go @@ -0,0 +1,149 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameOwner = "owner" + +func init() { + registerMatchMaker(ownerMarshaler{}) +} + +// ownerMarshaler implements matchMaker for owner matching. +type ownerMarshaler struct{} + +// name implements matchMaker.name. +func (ownerMarshaler) name() string { + return matcherNameOwner +} + +// marshal implements matchMaker.marshal. +func (ownerMarshaler) marshal(mr stack.Matcher) []byte { + matcher := mr.(*OwnerMatcher) + iptOwnerInfo := linux.IPTOwnerInfo{ + UID: matcher.uid, + GID: matcher.gid, + } + + // Support for UID and GID match. + if matcher.matchUID { + iptOwnerInfo.Match = linux.XT_OWNER_UID + if matcher.invertUID { + iptOwnerInfo.Invert = linux.XT_OWNER_UID + } + } + if matcher.matchGID { + iptOwnerInfo.Match |= linux.XT_OWNER_GID + if matcher.invertGID { + iptOwnerInfo.Invert |= linux.XT_OWNER_GID + } + } + + buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo) + return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo)) +} + +// unmarshal implements matchMaker.unmarshal. +func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + if len(buf) < linux.SizeOfIPTOwnerInfo { + return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may + // exceed what's strictly necessary to hold matchData. + var matchData linux.IPTOwnerInfo + binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData) + + var owner OwnerMatcher + owner.uid = matchData.UID + owner.gid = matchData.GID + + // Check flags. + if matchData.Match&linux.XT_OWNER_UID != 0 { + owner.matchUID = true + if matchData.Invert&linux.XT_OWNER_UID != 0 { + owner.invertUID = true + } + } + if matchData.Match&linux.XT_OWNER_GID != 0 { + owner.matchGID = true + if matchData.Invert&linux.XT_OWNER_GID != 0 { + owner.invertGID = true + } + } + + return &owner, nil +} + +type OwnerMatcher struct { + uid uint32 + gid uint32 + matchUID bool + matchGID bool + invertUID bool + invertGID bool +} + +// Name implements Matcher.Name. +func (*OwnerMatcher) Name() string { + return matcherNameOwner +} + +// Match implements Matcher.Match. +func (om *OwnerMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { + // Support only for OUTPUT chain. + // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also. + if hook != stack.Output { + return false, true + } + + // If the packet owner is not set, drop the packet. + if pkt.Owner == nil { + return false, true + } + + var matches bool + // Check for UID match. + if om.matchUID { + if pkt.Owner.UID() == om.uid { + matches = true + } + if matches == om.invertUID { + return false, false + } + } + + // Check for GID match. + if om.matchGID { + matches = false + if pkt.Owner.GID() == om.gid { + matches = true + } + if matches == om.invertGID { + return false, false + } + } + + return true, false +} diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go new file mode 100644 index 000000000..8ebdaff18 --- /dev/null +++ b/pkg/sentry/socket/netfilter/targets.go @@ -0,0 +1,282 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "errors" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +// errorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const errorTargetName = "ERROR" + +// redirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port/destination IP for packets. +const redirectTargetName = "REDIRECT" + +func marshalTarget(target stack.Target) []byte { + switch tg := target.(type) { + case stack.AcceptTarget: + return marshalStandardTarget(stack.RuleAccept) + case stack.DropTarget: + return marshalStandardTarget(stack.RuleDrop) + case stack.ErrorTarget: + return marshalErrorTarget(errorTargetName) + case stack.UserChainTarget: + return marshalErrorTarget(tg.Name) + case stack.ReturnTarget: + return marshalStandardTarget(stack.RuleReturn) + case stack.RedirectTarget: + return marshalRedirectTarget(tg) + case JumpTarget: + return marshalJumpTarget(tg) + default: + panic(fmt.Errorf("unknown target of type %T", target)) + } +} + +func marshalStandardTarget(verdict stack.RuleVerdict) []byte { + nflog("convert to binary: marshalling standard target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + Verdict: translateFromStandardVerdict(verdict), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalErrorTarget(errorName string) []byte { + // This is an error target named error + target := linux.XTErrorTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTErrorTarget, + }, + } + copy(target.Name[:], errorName) + copy(target.Target.Name[:], errorTargetName) + + ret := make([]byte, 0, linux.SizeOfXTErrorTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalRedirectTarget(rt stack.RedirectTarget) []byte { + // This is a redirect target named redirect + target := linux.XTRedirectTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTRedirectTarget, + }, + } + copy(target.Target.Name[:], redirectTargetName) + + ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) + target.NfRange.RangeSize = 1 + if rt.RangeProtoSpecified { + target.NfRange.RangeIPV4.Flags |= linux.NF_NAT_RANGE_PROTO_SPECIFIED + } + // Convert port from little endian to big endian. + port := make([]byte, 2) + binary.LittleEndian.PutUint16(port, rt.MinPort) + target.NfRange.RangeIPV4.MinPort = binary.BigEndian.Uint16(port) + binary.LittleEndian.PutUint16(port, rt.MaxPort) + target.NfRange.RangeIPV4.MaxPort = binary.BigEndian.Uint16(port) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +func marshalJumpTarget(jt JumpTarget) []byte { + nflog("convert to binary: marshalling jump target") + + // The target's name will be the empty string. + target := linux.XTStandardTarget{ + Target: linux.XTEntryTarget{ + TargetSize: linux.SizeOfXTStandardTarget, + }, + // Verdict is overloaded by the ABI. When positive, it holds + // the jump offset from the start of the table. + Verdict: int32(jt.Offset), + } + + ret := make([]byte, 0, linux.SizeOfXTStandardTarget) + return binary.Marshal(ret, usermem.ByteOrder, target) +} + +// translateFromStandardVerdict translates verdicts the same way as the iptables +// tool. +func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 { + switch verdict { + case stack.RuleAccept: + return -linux.NF_ACCEPT - 1 + case stack.RuleDrop: + return -linux.NF_DROP - 1 + case stack.RuleReturn: + return linux.NF_RETURN + default: + // TODO(gvisor.dev/issue/170): Support Jump. + panic(fmt.Sprintf("unknown standard verdict: %d", verdict)) + } +} + +// translateToStandardTarget translates from the value in a +// linux.XTStandardTarget to an stack.Verdict. +func translateToStandardTarget(val int32) (stack.Target, error) { + // TODO(gvisor.dev/issue/170): Support other verdicts. + switch val { + case -linux.NF_ACCEPT - 1: + return stack.AcceptTarget{}, nil + case -linux.NF_DROP - 1: + return stack.DropTarget{}, nil + case -linux.NF_QUEUE - 1: + return nil, errors.New("unsupported iptables verdict QUEUE") + case linux.NF_RETURN: + return stack.ReturnTarget{}, nil + default: + return nil, fmt.Errorf("unknown iptables verdict %d", val) + } +} + +// parseTarget parses a target from optVal. optVal should contain only the +// target. +func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) { + nflog("set entries: parsing target of size %d", len(optVal)) + if len(optVal) < linux.SizeOfXTEntryTarget { + return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal)) + } + var target linux.XTEntryTarget + buf := optVal[:linux.SizeOfXTEntryTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &target) + switch target.Name.String() { + case "": + // Standard target. + if len(optVal) != linux.SizeOfXTStandardTarget { + return nil, fmt.Errorf("optVal has wrong size for standard target %d", len(optVal)) + } + var standardTarget linux.XTStandardTarget + buf = optVal[:linux.SizeOfXTStandardTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget) + + if standardTarget.Verdict < 0 { + // A Verdict < 0 indicates a non-jump verdict. + return translateToStandardTarget(standardTarget.Verdict) + } + // A verdict >= 0 indicates a jump. + return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil + + case errorTargetName: + // Error target. + if len(optVal) != linux.SizeOfXTErrorTarget { + return nil, fmt.Errorf("optVal has insufficient size for error target %d", len(optVal)) + } + var errorTarget linux.XTErrorTarget + buf = optVal[:linux.SizeOfXTErrorTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + + // Error targets are used in 2 cases: + // * An actual error case. These rules have an error + // named errorTargetName. The last entry of the table + // is usually an error case to catch any packets that + // somehow fall through every rule. + // * To mark the start of a user defined chain. These + // rules have an error with the name of the chain. + switch name := errorTarget.Name.String(); name { + case errorTargetName: + nflog("set entries: error target") + return stack.ErrorTarget{}, nil + default: + // User defined chain. + nflog("set entries: user-defined target %q", name) + return stack.UserChainTarget{Name: name}, nil + } + + case redirectTargetName: + // Redirect target. + if len(optVal) < linux.SizeOfXTRedirectTarget { + return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) + } + + if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + var redirectTarget linux.XTRedirectTarget + buf = optVal[:linux.SizeOfXTRedirectTarget] + binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + + // Copy linux.XTRedirectTarget to stack.RedirectTarget. + var target stack.RedirectTarget + nfRange := redirectTarget.NfRange + + // RangeSize should be 1. + if nfRange.RangeSize != 1 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // TODO(gvisor.dev/issue/170): Check if the flags are valid. + // Also check if we need to map ports or IP. + // For now, redirect target only supports destination port change. + // Port range and IP range are not supported yet. + if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + target.RangeProtoSpecified = true + + target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:]) + + // TODO(gvisor.dev/issue/170): Port range is not supported yet. + if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { + return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + } + + // Convert port from big endian to little endian. + port := make([]byte, 2) + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort) + target.MinPort = binary.LittleEndian.Uint16(port) + + binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort) + target.MaxPort = binary.LittleEndian.Uint16(port) + return target, nil + } + + // Unknown target. + return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet", target.Name.String()) +} + +// JumpTarget implements stack.Target. +type JumpTarget struct { + // Offset is the byte offset of the rule to jump to. It is used for + // marshaling and unmarshaling. + Offset uint32 + + // RuleNum is the rule to jump to. + RuleNum int +} + +// Action implements stack.Target.Action. +func (jt JumpTarget) Action(*stack.PacketBuffer, *stack.ConnTrack, stack.Hook, *stack.GSO, *stack.Route, tcpip.Address) (stack.RuleVerdict, int) { + return stack.RuleJump, jt.RuleNum +} diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go new file mode 100644 index 000000000..0bfd6c1f4 --- /dev/null +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -0,0 +1,130 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameTCP = "tcp" + +func init() { + registerMatchMaker(tcpMarshaler{}) +} + +// tcpMarshaler implements matchMaker for TCP matching. +type tcpMarshaler struct{} + +// name implements matchMaker.name. +func (tcpMarshaler) name() string { + return matcherNameTCP +} + +// marshal implements matchMaker.marshal. +func (tcpMarshaler) marshal(mr stack.Matcher) []byte { + matcher := mr.(*TCPMatcher) + xttcp := linux.XTTCP{ + SourcePortStart: matcher.sourcePortStart, + SourcePortEnd: matcher.sourcePortEnd, + DestinationPortStart: matcher.destinationPortStart, + DestinationPortEnd: matcher.destinationPortEnd, + } + buf := make([]byte, 0, linux.SizeOfXTTCP) + return marshalEntryMatch(matcherNameTCP, binary.Marshal(buf, usermem.ByteOrder, xttcp)) +} + +// unmarshal implements matchMaker.unmarshal. +func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + if len(buf) < linux.SizeOfXTTCP { + return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may + // exceed what's strictly necessary to hold matchData. + var matchData linux.XTTCP + binary.Unmarshal(buf[:linux.SizeOfXTTCP], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed XTTCP: %+v", matchData) + + if matchData.Option != 0 || + matchData.FlagMask != 0 || + matchData.FlagCompare != 0 || + matchData.InverseFlags != 0 { + return nil, fmt.Errorf("unsupported TCP matcher flags set") + } + + if filter.Protocol != header.TCPProtocolNumber { + return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + } + + return &TCPMatcher{ + sourcePortStart: matchData.SourcePortStart, + sourcePortEnd: matchData.SourcePortEnd, + destinationPortStart: matchData.DestinationPortStart, + destinationPortEnd: matchData.DestinationPortEnd, + }, nil +} + +// TCPMatcher matches TCP packets and their headers. It implements Matcher. +type TCPMatcher struct { + sourcePortStart uint16 + sourcePortEnd uint16 + destinationPortStart uint16 + destinationPortEnd uint16 +} + +// Name implements Matcher.Name. +func (*TCPMatcher) Name() string { + return matcherNameTCP +} + +// Match implements Matcher.Match. +func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + + if netHeader.TransportProtocol() != header.TCPProtocolNumber { + return false, false + } + + // We dont't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } + + tcpHeader := header.TCP(pkt.TransportHeader().View()) + if len(tcpHeader) < header.TCPMinimumSize { + // There's no valid TCP header here, so we drop the packet immediately. + return false, true + } + + // Check whether the source and destination ports are within the + // matching range. + if sourcePort := tcpHeader.SourcePort(); sourcePort < tm.sourcePortStart || tm.sourcePortEnd < sourcePort { + return false, false + } + if destinationPort := tcpHeader.DestinationPort(); destinationPort < tm.destinationPortStart || tm.destinationPortEnd < destinationPort { + return false, false + } + + return true, false +} diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go new file mode 100644 index 000000000..7ed05461d --- /dev/null +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -0,0 +1,129 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +const matcherNameUDP = "udp" + +func init() { + registerMatchMaker(udpMarshaler{}) +} + +// udpMarshaler implements matchMaker for UDP matching. +type udpMarshaler struct{} + +// name implements matchMaker.name. +func (udpMarshaler) name() string { + return matcherNameUDP +} + +// marshal implements matchMaker.marshal. +func (udpMarshaler) marshal(mr stack.Matcher) []byte { + matcher := mr.(*UDPMatcher) + xtudp := linux.XTUDP{ + SourcePortStart: matcher.sourcePortStart, + SourcePortEnd: matcher.sourcePortEnd, + DestinationPortStart: matcher.destinationPortStart, + DestinationPortEnd: matcher.destinationPortEnd, + } + buf := make([]byte, 0, linux.SizeOfXTUDP) + return marshalEntryMatch(matcherNameUDP, binary.Marshal(buf, usermem.ByteOrder, xtudp)) +} + +// unmarshal implements matchMaker.unmarshal. +func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) { + if len(buf) < linux.SizeOfXTUDP { + return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf)) + } + + // For alignment reasons, the match's total size may exceed what's + // strictly necessary to hold matchData. + var matchData linux.XTUDP + binary.Unmarshal(buf[:linux.SizeOfXTUDP], usermem.ByteOrder, &matchData) + nflog("parseMatchers: parsed XTUDP: %+v", matchData) + + if matchData.InverseFlags != 0 { + return nil, fmt.Errorf("unsupported UDP matcher inverse flags set") + } + + if filter.Protocol != header.UDPProtocolNumber { + return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + } + + return &UDPMatcher{ + sourcePortStart: matchData.SourcePortStart, + sourcePortEnd: matchData.SourcePortEnd, + destinationPortStart: matchData.DestinationPortStart, + destinationPortEnd: matchData.DestinationPortEnd, + }, nil +} + +// UDPMatcher matches UDP packets and their headers. It implements Matcher. +type UDPMatcher struct { + sourcePortStart uint16 + sourcePortEnd uint16 + destinationPortStart uint16 + destinationPortEnd uint16 +} + +// Name implements Matcher.Name. +func (*UDPMatcher) Name() string { + return matcherNameUDP +} + +// Match implements Matcher.Match. +func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) { + netHeader := header.IPv4(pkt.NetworkHeader().View()) + + // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved + // into the stack.Check codepath as matchers are added. + if netHeader.TransportProtocol() != header.UDPProtocolNumber { + return false, false + } + + // We dont't match fragments. + if frag := netHeader.FragmentOffset(); frag != 0 { + if frag == 1 { + return false, true + } + return false, false + } + + udpHeader := header.UDP(pkt.TransportHeader().View()) + if len(udpHeader) < header.UDPMinimumSize { + // There's no valid UDP header here, so we drop the packet immediately. + return false, true + } + + // Check whether the source and destination ports are within the + // matching range. + if sourcePort := udpHeader.SourcePort(); sourcePort < um.sourcePortStart || um.sourcePortEnd < sourcePort { + return false, false + } + if destinationPort := udpHeader.DestinationPort(); destinationPort < um.destinationPortStart || um.destinationPortEnd < destinationPort { + return false, false + } + + return true, false +} diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD index f95803f91..0546801bf 100644 --- a/pkg/sentry/socket/netlink/BUILD +++ b/pkg/sentry/socket/netlink/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -7,29 +7,48 @@ go_library( srcs = [ "message.go", "provider.go", + "provider_vfs2.go", "socket.go", + "socket_vfs2.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", "//pkg/binary", + "//pkg/context", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", + "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netlink/port", "//pkg/sentry/socket/unix", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", + "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", + ], +) + +go_test( + name = "netlink_test", + size = "small", + srcs = [ + "message_test.go", + ], + deps = [ + ":netlink", + "//pkg/abi/linux", ], ) diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go index ce0a1afd0..0899c61d1 100644 --- a/pkg/sentry/socket/netlink/message.go +++ b/pkg/sentry/socket/netlink/message.go @@ -20,18 +20,19 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/usermem" ) -// alignUp rounds a length up to an alignment. +// alignPad returns the length of padding required for alignment. // // Preconditions: align is a power of two. -func alignUp(length int, align uint) int { - return (length + int(align) - 1) &^ (int(align) - 1) +func alignPad(length int, align uint) int { + return binary.AlignUp(length, align) - length } // Message contains a complete serialized netlink message. type Message struct { + hdr linux.NetlinkMessageHeader buf []byte } @@ -40,10 +41,86 @@ type Message struct { // The header length will be updated by Finalize. func NewMessage(hdr linux.NetlinkMessageHeader) *Message { return &Message{ + hdr: hdr, buf: binary.Marshal(nil, usermem.ByteOrder, hdr), } } +// ParseMessage parses the first message seen at buf, returning the rest of the +// buffer. If message is malformed, ok of false is returned. For last message, +// padding check is loose, if there isn't enought padding, whole buf is consumed +// and ok is set to true. +func ParseMessage(buf []byte) (msg *Message, rest []byte, ok bool) { + b := BytesView(buf) + + hdrBytes, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return + } + var hdr linux.NetlinkMessageHeader + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + // Msg portion. + totalMsgLen := int(hdr.Length) + _, ok = b.Extract(totalMsgLen - linux.NetlinkMessageHeaderSize) + if !ok { + return + } + + // Padding. + numPad := alignPad(totalMsgLen, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return + } + + return &Message{ + hdr: hdr, + buf: buf[:totalMsgLen], + }, []byte(b), true +} + +// Header returns the header of this message. +func (m *Message) Header() linux.NetlinkMessageHeader { + return m.hdr +} + +// GetData unmarshals the payload message header from this netlink message, and +// returns the attributes portion. +func (m *Message) GetData(msg interface{}) (AttrsView, bool) { + b := BytesView(m.buf) + + _, ok := b.Extract(linux.NetlinkMessageHeaderSize) + if !ok { + return nil, false + } + + size := int(binary.Size(msg)) + msgBytes, ok := b.Extract(size) + if !ok { + return nil, false + } + binary.Unmarshal(msgBytes, usermem.ByteOrder, msg) + + numPad := alignPad(linux.NetlinkMessageHeaderSize+size, linux.NLMSG_ALIGNTO) + // Linux permits the last message not being aligned, just consume all of it. + // Ref: net/netlink/af_netlink.c:netlink_rcv_skb + if numPad > len(b) { + numPad = len(b) + } + _, ok = b.Extract(numPad) + if !ok { + return nil, false + } + + return AttrsView(b), true +} + // Finalize returns the []byte containing the entire message, with the total // length set in the message header. The Message must not be modified after // calling Finalize. @@ -54,7 +131,7 @@ func (m *Message) Finalize() []byte { // Align the message. Note that the message length in the header (set // above) is the useful length of the message, not the total aligned // length. See net/netlink/af_netlink.c:__nlmsg_put. - aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO) + aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO) m.putZeros(aligned - len(m.buf)) return m.buf } @@ -89,7 +166,7 @@ func (m *Message) PutAttr(atype uint16, v interface{}) { m.Put(v) // Align the attribute. - aligned := alignUp(l, linux.NLA_ALIGNTO) + aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } @@ -106,7 +183,7 @@ func (m *Message) PutAttrString(atype uint16, s string) { m.putZeros(1) // Align the attribute. - aligned := alignUp(l, linux.NLA_ALIGNTO) + aligned := binary.AlignUp(l, linux.NLA_ALIGNTO) m.putZeros(aligned - l) } @@ -157,3 +234,48 @@ func (ms *MessageSet) AddMessage(hdr linux.NetlinkMessageHeader) *Message { ms.Messages = append(ms.Messages, m) return m } + +// AttrsView is a view into the attributes portion of a netlink message. +type AttrsView []byte + +// Empty returns whether there is no attribute left in v. +func (v AttrsView) Empty() bool { + return len(v) == 0 +} + +// ParseFirst parses first netlink attribute at the beginning of v. +func (v AttrsView) ParseFirst() (hdr linux.NetlinkAttrHeader, value []byte, rest AttrsView, ok bool) { + b := BytesView(v) + + hdrBytes, ok := b.Extract(linux.NetlinkAttrHeaderSize) + if !ok { + return + } + binary.Unmarshal(hdrBytes, usermem.ByteOrder, &hdr) + + value, ok = b.Extract(int(hdr.Length) - linux.NetlinkAttrHeaderSize) + if !ok { + return + } + + _, ok = b.Extract(alignPad(int(hdr.Length), linux.NLA_ALIGNTO)) + if !ok { + return + } + + return hdr, value, AttrsView(b), ok +} + +// BytesView supports extracting data from a byte slice with bounds checking. +type BytesView []byte + +// Extract removes the first n bytes from v and returns it. If n is out of +// bounds, it returns false. +func (v *BytesView) Extract(n int) ([]byte, bool) { + if n < 0 || n > len(*v) { + return nil, false + } + extracted := (*v)[:n] + *v = (*v)[n:] + return extracted, true +} diff --git a/pkg/sentry/socket/netlink/message_test.go b/pkg/sentry/socket/netlink/message_test.go new file mode 100644 index 000000000..ef13d9386 --- /dev/null +++ b/pkg/sentry/socket/netlink/message_test.go @@ -0,0 +1,312 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package message_test + +import ( + "bytes" + "reflect" + "testing" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" +) + +type dummyNetlinkMsg struct { + Foo uint16 +} + +func TestParseMessage(t *testing.T) { + tests := []struct { + desc string + input []byte + + header linux.NetlinkMessageHeader + dataMsg *dummyNetlinkMsg + restLen int + ok bool + }{ + { + desc: "valid", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid with next message", + input: []byte{ + 0x14, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + 0xFF, // Next message (rest) + }, + header: linux.NetlinkMessageHeader{ + Length: 20, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 1, + ok: true, + }, + { + desc: "valid for last message without padding", + input: []byte{ + 0x12, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + }, + header: linux.NetlinkMessageHeader{ + Length: 18, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "valid for last message not to be aligned", + input: []byte{ + 0x13, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, // Data message + 0x00, // Excessive 1 byte permitted at end + }, + header: linux.NetlinkMessageHeader{ + Length: 19, + Type: 1, + Flags: 2, + Seq: 3, + PortID: 4, + }, + dataMsg: &dummyNetlinkMsg{ + Foo: 0x3130, + }, + restLen: 0, + ok: true, + }, + { + desc: "header.Length too short", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header.Length too long", + input: []byte{ + 0xFF, 0xFF, 0x00, 0x00, // Length + 0x01, 0x00, // Type + 0x02, 0x00, // Flags + 0x03, 0x00, 0x00, 0x00, // Seq + 0x04, 0x00, 0x00, 0x00, // PortID + 0x30, 0x31, 0x00, 0x00, // Data message with 2 bytes padding + }, + ok: false, + }, + { + desc: "header incomplete", + input: []byte{ + 0x04, 0x00, 0x00, 0x00, // Length + }, + ok: false, + }, + { + desc: "empty message", + input: []byte{}, + ok: false, + }, + } + for _, test := range tests { + msg, rest, ok := netlink.ParseMessage(test.input) + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + continue + } + if !test.ok { + continue + } + if !reflect.DeepEqual(msg.Header(), test.header) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, msg.Header(), test.header) + } + + dataMsg := &dummyNetlinkMsg{} + _, dataOk := msg.GetData(dataMsg) + if !dataOk { + t.Errorf("%v: GetData.ok = %v, want = true", test.desc, dataOk) + } else if !reflect.DeepEqual(dataMsg, test.dataMsg) { + t.Errorf("%v: GetData.msg = %+v, want = %+v", test.desc, dataMsg, test.dataMsg) + } + + if got, want := rest, test.input[len(test.input)-test.restLen:]; !bytes.Equal(got, want) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, got, want) + } + } +} + +func TestAttrView(t *testing.T) { + tests := []struct { + desc string + input []byte + + // Outputs for ParseFirst. + hdr linux.NetlinkAttrHeader + value []byte + restLen int + ok bool + + // Outputs for Empty. + isEmpty bool + }{ + { + desc: "valid", + input: []byte{ + 0x06, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x00, 0x00, // Data with 2 bytes padding + }, + hdr: linux.NetlinkAttrHeader{ + Length: 6, + Type: 1, + }, + value: []byte{0x30, 0x31}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 0, + ok: true, + isEmpty: false, + }, + { + desc: "at alignment with rest data", + input: []byte{ + 0x08, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + 0xFF, 0xFE, // Rest data + }, + hdr: linux.NetlinkAttrHeader{ + Length: 8, + Type: 1, + }, + value: []byte{0x30, 0x31, 0x32, 0x33}, + restLen: 2, + ok: true, + isEmpty: false, + }, + { + desc: "hdr.Length too long", + input: []byte{ + 0xFF, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "hdr.Length too short", + input: []byte{ + 0x01, 0x00, // Length + 0x01, 0x00, // Type + 0x30, 0x31, 0x32, 0x33, // Data + }, + ok: false, + isEmpty: false, + }, + { + desc: "empty", + input: []byte{}, + ok: false, + isEmpty: true, + }, + } + for _, test := range tests { + attrs := netlink.AttrsView(test.input) + + // Test ParseFirst(). + hdr, value, rest, ok := attrs.ParseFirst() + if ok != test.ok { + t.Errorf("%v: got ok = %v, want = %v", test.desc, ok, test.ok) + } else if test.ok { + if !reflect.DeepEqual(hdr, test.hdr) { + t.Errorf("%v: got hdr = %+v, want = %+v", test.desc, hdr, test.hdr) + } + if !bytes.Equal(value, test.value) { + t.Errorf("%v: got value = %v, want = %v", test.desc, value, test.value) + } + if wantRest := test.input[len(test.input)-test.restLen:]; !bytes.Equal(rest, wantRest) { + t.Errorf("%v: got rest = %v, want = %v", test.desc, rest, wantRest) + } + } + + // Test Empty(). + if got, want := attrs.Empty(), test.isEmpty; got != want { + t.Errorf("%v: got empty = %v, want = %v", test.desc, got, want) + } + } +} diff --git a/pkg/sentry/socket/netlink/port/BUILD b/pkg/sentry/socket/netlink/port/BUILD index 463544c1a..3a22923d8 100644 --- a/pkg/sentry/socket/netlink/port/BUILD +++ b/pkg/sentry/socket/netlink/port/BUILD @@ -1,17 +1,16 @@ -load("@io_bazel_rules_go//go:def.bzl", "go_test") -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) go_library( name = "port", srcs = ["port.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port", visibility = ["//pkg/sentry:internal"], + deps = ["//pkg/sync"], ) go_test( name = "port_test", srcs = ["port_test.go"], - embed = [":port"], + library = ":port", ) diff --git a/pkg/sentry/socket/netlink/port/port.go b/pkg/sentry/socket/netlink/port/port.go index e9d3275b1..2cd3afc22 100644 --- a/pkg/sentry/socket/netlink/port/port.go +++ b/pkg/sentry/socket/netlink/port/port.go @@ -24,7 +24,8 @@ import ( "fmt" "math" "math/rand" - "sync" + + "gvisor.dev/gvisor/pkg/sync" ) // maxPorts is a sanity limit on the maximum number of ports to allocate per diff --git a/pkg/sentry/socket/netlink/provider.go b/pkg/sentry/socket/netlink/provider.go index 689cad997..31e374833 100644 --- a/pkg/sentry/socket/netlink/provider.go +++ b/pkg/sentry/socket/netlink/provider.go @@ -18,7 +18,7 @@ import ( "fmt" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/socket" @@ -30,12 +30,19 @@ type Protocol interface { // Protocol returns the Linux netlink protocol value. Protocol() int + // CanSend returns true if this protocol may ever send messages. + // + // TODO(gvisor.dev/issue/1119): This is a workaround to allow + // advertising support for otherwise unimplemented features on sockets + // that will never send messages, thus making those features no-ops. + CanSend() bool + // ProcessMessage processes a single message from userspace. // // If err == nil, any messages added to ms will be sent back to the // other end of the socket. Setting ms.Multi will cause an NLMSG_DONE // message to be sent even if ms contains no messages. - ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *MessageSet) *syserr.Error + ProcessMessage(ctx context.Context, msg *Message, ms *MessageSet) *syserr.Error } // Provider is a function that creates a new Protocol for a specific netlink @@ -60,6 +67,8 @@ func RegisterProvider(protocol int, provider Provider) { protocols[protocol] = provider } +// LINT.IfChange + // socketProvider implements socket.Provider. type socketProvider struct { } @@ -88,7 +97,7 @@ func (*socketProvider) Socket(t *kernel.Task, stype linux.SockType, protocol int } d := socket.NewDirent(t, netlinkSocketDevice) - defer d.DecRef() + defer d.DecRef(t) return fs.NewFile(t, d, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, s), nil } @@ -98,7 +107,10 @@ func (*socketProvider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.Fi return nil, nil, syserr.ErrNotSupported } +// LINT.ThenChange(./provider_vfs2.go) + // init registers the socket provider. func init() { socket.RegisterProvider(linux.AF_NETLINK, &socketProvider{}) + socket.RegisterProviderVFS2(linux.AF_NETLINK, &socketProviderVFS2{}) } diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go new file mode 100644 index 000000000..bb205be0d --- /dev/null +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -0,0 +1,69 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" +) + +// socketProviderVFS2 implements socket.Provider. +type socketProviderVFS2 struct { +} + +// Socket implements socket.Provider.Socket. +func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Netlink sockets must be specified as datagram or raw, but they + // behave the same regardless of type. + if stype != linux.SOCK_DGRAM && stype != linux.SOCK_RAW { + return nil, syserr.ErrSocketNotSupported + } + + provider, ok := protocols[protocol] + if !ok { + return nil, syserr.ErrProtocolNotSupported + } + + p, err := provider(t) + if err != nil { + return nil, err + } + + s, err := NewVFS2(t, stype, p) + if err != nil { + return nil, err + } + + vfsfd := &s.vfsfd + mnt := t.Kernel().SocketMount() + d := sockfs.NewDentry(t.Credentials(), mnt) + if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// Pair implements socket.Provider.Pair by returning an error. +func (*socketProviderVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Netlink sockets never supports creating socket pairs. + return nil, nil, syserr.ErrNotSupported +} diff --git a/pkg/sentry/socket/netlink/route/BUILD b/pkg/sentry/socket/netlink/route/BUILD index 1d4912753..93127398d 100644 --- a/pkg/sentry/socket/netlink/route/BUILD +++ b/pkg/sentry/socket/netlink/route/BUILD @@ -1,15 +1,16 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) go_library( name = "route", - srcs = ["protocol.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netlink/route", + srcs = [ + "protocol.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", - "//pkg/sentry/context", + "//pkg/context", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index cc70ac237..c84d8bd7c 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -17,9 +17,10 @@ package route import ( "bytes" + "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -61,8 +62,13 @@ func (p *Protocol) Protocol() int { return linux.NETLINK_ROUTE } -// dumpLinks handles RTM_GETLINK + NLM_F_DUMP requests. -func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return true +} + +// dumpLinks handles RTM_GETLINK dump requests. +func (p *Protocol) dumpLinks(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // NLM_F_DUMP + RTM_GETLINK messages are supposed to include an // ifinfomsg. However, Linux <3.9 only checked for rtgenmsg, and some // userspace applications (including glibc) still include rtgenmsg. @@ -86,38 +92,105 @@ func (p *Protocol) dumpLinks(ctx context.Context, hdr linux.NetlinkMessageHeader return nil } - for id, i := range stack.Interfaces() { - m := ms.AddMessage(linux.NetlinkMessageHeader{ - Type: linux.RTM_NEWLINK, - }) + for idx, i := range stack.Interfaces() { + addNewLinkMessage(ms, idx, i) + } - m.Put(linux.InterfaceInfoMessage{ - Family: linux.AF_UNSPEC, - Type: i.DeviceType, - Index: id, - Flags: i.Flags, - }) + return nil +} - m.PutAttrString(linux.IFLA_IFNAME, i.Name) - m.PutAttr(linux.IFLA_MTU, i.MTU) +// getLinks handles RTM_GETLINK requests. +func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network devices. + return nil + } - mac := make([]byte, 6) - brd := mac - if len(i.Addr) > 0 { - mac = i.Addr - brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) + // Parse message. + var ifi linux.InterfaceInfoMessage + attrs, ok := msg.GetData(&ifi) + if !ok { + return syserr.ErrInvalidArgument + } + + // Parse attributes. + var byName []byte + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument } - m.PutAttr(linux.IFLA_ADDRESS, mac) - m.PutAttr(linux.IFLA_BROADCAST, brd) + attrs = rest - // TODO(gvisor.dev/issue/578): There are many more attributes. + switch ahdr.Type { + case linux.IFLA_IFNAME: + if len(value) < 1 { + return syserr.ErrInvalidArgument + } + byName = value[:len(value)-1] + + // TODO(gvisor.dev/issue/578): Support IFLA_EXT_MASK. + } } + found := false + for idx, i := range stack.Interfaces() { + switch { + case ifi.Index > 0: + if idx != ifi.Index { + continue + } + case byName != nil: + if string(byName) != i.Name { + continue + } + default: + // Criteria not specified. + return syserr.ErrInvalidArgument + } + + addNewLinkMessage(ms, idx, i) + found = true + break + } + if !found { + return syserr.ErrNoDevice + } return nil } -// dumpAddrs handles RTM_GETADDR + NLM_F_DUMP requests. -func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +// addNewLinkMessage appends RTM_NEWLINK message for the given interface into +// the message set. +func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.RTM_NEWLINK, + }) + + m.Put(linux.InterfaceInfoMessage{ + Family: linux.AF_UNSPEC, + Type: i.DeviceType, + Index: idx, + Flags: i.Flags, + }) + + m.PutAttrString(linux.IFLA_IFNAME, i.Name) + m.PutAttr(linux.IFLA_MTU, i.MTU) + + mac := make([]byte, 6) + brd := mac + if len(i.Addr) > 0 { + mac = i.Addr + brd = bytes.Repeat([]byte{0xff}, len(i.Addr)) + } + m.PutAttr(linux.IFLA_ADDRESS, mac) + m.PutAttr(linux.IFLA_BROADCAST, brd) + + // TODO(gvisor.dev/issue/578): There are many more attributes. +} + +// dumpAddrs handles RTM_GETADDR dump requests. +func (p *Protocol) dumpAddrs(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETADDR dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. @@ -149,6 +222,7 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader Index: uint32(id), }) + m.PutAttr(linux.IFA_LOCAL, []byte(a.Addr)) m.PutAttr(linux.IFA_ADDRESS, []byte(a.Addr)) // TODO(gvisor.dev/issue/578): There are many more attributes. @@ -158,22 +232,136 @@ func (p *Protocol) dumpAddrs(ctx context.Context, hdr linux.NetlinkMessageHeader return nil } -// dumpRoutes handles RTM_GETROUTE + NLM_F_DUMP requests. -func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +// commonPrefixLen reports the length of the longest IP address prefix. +// This is a simplied version from Golang's src/net/addrselect.go. +func commonPrefixLen(a, b []byte) (cpl int) { + for len(a) > 0 { + if a[0] == b[0] { + cpl += 8 + a = a[1:] + b = b[1:] + continue + } + bits := 8 + ab, bb := a[0], b[0] + for { + ab >>= 1 + bb >>= 1 + bits-- + if ab == bb { + cpl += bits + return + } + } + } + return +} + +// fillRoute returns the Route using LPM algorithm. Refer to Linux's +// net/ipv4/route.c:rt_fill_info(). +func fillRoute(routes []inet.Route, addr []byte) (inet.Route, *syserr.Error) { + family := uint8(linux.AF_INET) + if len(addr) != 4 { + family = linux.AF_INET6 + } + + idx := -1 // Index of the Route rule to be returned. + idxDef := -1 // Index of the default route rule. + prefix := 0 // Current longest prefix. + for i, route := range routes { + if route.Family != family { + continue + } + + if len(route.GatewayAddr) > 0 && route.DstLen == 0 { + idxDef = i + continue + } + + cpl := commonPrefixLen(addr, route.DstAddr) + if cpl < int(route.DstLen) { + continue + } + cpl = int(route.DstLen) + if cpl > prefix { + idx = i + prefix = cpl + } + } + if idx == -1 { + idx = idxDef + } + if idx == -1 { + return inet.Route{}, syserr.ErrNoRoute + } + + route := routes[idx] + if family == linux.AF_INET { + route.DstLen = 32 + } else { + route.DstLen = 128 + } + route.DstAddr = addr + route.Flags |= linux.RTM_F_CLONED // This route is cloned. + return route, nil +} + +// parseForDestination parses a message as format of RouteMessage-RtAttr-dst. +func parseForDestination(msg *netlink.Message) ([]byte, *syserr.Error) { + var rtMsg linux.RouteMessage + attrs, ok := msg.GetData(&rtMsg) + if !ok { + return nil, syserr.ErrInvalidArgument + } + // iproute2 added the RTM_F_LOOKUP_TABLE flag in version v4.4.0. See + // commit bc234301af12. Note we don't check this flag for backward + // compatibility. + if rtMsg.Flags != 0 && rtMsg.Flags != linux.RTM_F_LOOKUP_TABLE { + return nil, syserr.ErrNotSupported + } + + // Expect first attribute is RTA_DST. + if hdr, value, _, ok := attrs.ParseFirst(); ok && hdr.Type == linux.RTA_DST { + return value, nil + } + return nil, syserr.ErrInvalidArgument +} + +// dumpRoutes handles RTM_GETROUTE requests. +func (p *Protocol) dumpRoutes(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { // RTM_GETROUTE dump requests need not contain anything more than the // netlink header and 1 byte protocol family common to all // NETLINK_ROUTE requests. - // We always send back an NLMSG_DONE. - ms.Multi = true - stack := inet.StackFromContext(ctx) if stack == nil { // No network routes. return nil } - for _, rt := range stack.RouteTable() { + hdr := msg.Header() + routeTables := stack.RouteTable() + + if hdr.Flags == linux.NLM_F_REQUEST { + dst, err := parseForDestination(msg) + if err != nil { + return err + } + route, err := fillRoute(routeTables, dst) + if err != nil { + // TODO(gvisor.dev/issue/1237): return NLMSG_ERROR with ENETUNREACH. + return syserr.ErrNotSupported + } + routeTables = append([]inet.Route{}, route) + } else if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP { + // We always send back an NLMSG_DONE. + ms.Multi = true + } else { + // TODO(b/68878065): Only above cases are supported. + return syserr.ErrNotSupported + } + + for _, rt := range routeTables { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.RTM_NEWROUTE, }) @@ -214,10 +402,55 @@ func (p *Protocol) dumpRoutes(ctx context.Context, hdr linux.NetlinkMessageHeade return nil } +// newAddr handles RTM_NEWADDR requests. +func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err == syscall.EEXIST { + flags := msg.Header().Flags + if flags&linux.NLM_F_EXCL != 0 { + return syserr.ErrExists + } + } else if err != nil { + return syserr.ErrInvalidArgument + } + } + } + return nil +} + // ProcessMessage implements netlink.Protocol.ProcessMessage. -func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageHeader, data []byte, ms *netlink.MessageSet) *syserr.Error { +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + hdr := msg.Header() + // All messages start with a 1 byte protocol family. - if len(data) < 1 { + var family uint8 + if _, ok := msg.GetData(&family); !ok { // Linux ignores messages missing the protocol family. See // net/core/rtnetlink.c:rtnetlink_rcv_msg. return nil @@ -231,22 +464,32 @@ func (p *Protocol) ProcessMessage(ctx context.Context, hdr linux.NetlinkMessageH } } - // TODO(b/68878065): Only the dump variant of the types below are - // supported. - if hdr.Flags&linux.NLM_F_DUMP != linux.NLM_F_DUMP { - return syserr.ErrNotSupported - } - - switch hdr.Type { - case linux.RTM_GETLINK: - return p.dumpLinks(ctx, hdr, data, ms) - case linux.RTM_GETADDR: - return p.dumpAddrs(ctx, hdr, data, ms) - case linux.RTM_GETROUTE: - return p.dumpRoutes(ctx, hdr, data, ms) - default: - return syserr.ErrNotSupported + if hdr.Flags&linux.NLM_F_DUMP == linux.NLM_F_DUMP { + // TODO(b/68878065): Only the dump variant of the types below are + // supported. + switch hdr.Type { + case linux.RTM_GETLINK: + return p.dumpLinks(ctx, msg, ms) + case linux.RTM_GETADDR: + return p.dumpAddrs(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } + } else if hdr.Flags&linux.NLM_F_REQUEST == linux.NLM_F_REQUEST { + switch hdr.Type { + case linux.RTM_GETLINK: + return p.getLink(ctx, msg, ms) + case linux.RTM_GETROUTE: + return p.dumpRoutes(ctx, msg, ms) + case linux.RTM_NEWADDR: + return p.newAddr(ctx, msg, ms) + default: + return syserr.ErrNotSupported + } } + return syserr.ErrNotSupported } // init registers the NETLINK_ROUTE provider. diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go index b2732ca29..68a9b9a96 100644 --- a/pkg/sentry/socket/netlink/socket.go +++ b/pkg/sentry/socket/netlink/socket.go @@ -17,27 +17,29 @@ package netlink import ( "math" - "sync" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netlink/port" "gvisor.dev/gvisor/pkg/sentry/socket/unix" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) const sizeOfInt32 int = 4 @@ -53,15 +55,19 @@ const ( maxSendBufferSize = 4 << 20 // 4MB ) +var errNoFilter = syserr.New("no filter attached", linux.ENOENT) + // netlinkSocketDevice is the netlink socket virtual device. var netlinkSocketDevice = device.NewAnonDevice() +// LINT.IfChange + // Socket is the base socket type for netlink sockets. // // This implementation only supports userspace sending and receiving messages // to/from the kernel. // -// Socket implements socket.Socket. +// Socket implements socket.Socket and transport.Credentialer. // // +stateify savable type Socket struct { @@ -72,6 +78,14 @@ type Socket struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout // ports provides netlink port allocation. @@ -104,9 +118,19 @@ type Socket struct { // sendBufferSize is the send buffer "size". We don't actually have a // fixed buffer but only consume this many bytes. sendBufferSize uint32 + + // passcred indicates if this socket wants SCM credentials. + passcred bool + + // filter indicates that this socket has a BPF filter "installed". + // + // TODO(gvisor.dev/issue/1119): We don't actually support filtering, + // this is just bookkeeping for tracking add/remove. + filter bool } var _ socket.Socket = (*Socket)(nil) +var _ transport.Credentialer = (*Socket)(nil) // NewSocket creates a new Socket. func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socket, *syserr.Error) { @@ -116,31 +140,33 @@ func NewSocket(t *kernel.Task, skType linux.SockType, protocol Protocol) (*Socke // Bind the endpoint for good measure so we can connect to it. The // bound address will never be exposed. if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { - ep.Close() + ep.Close(t) return nil, err } // Create a connection from which the kernel can write messages. connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) if err != nil { - ep.Close() + ep.Close(t) return nil, err } return &Socket{ - ports: t.Kernel().NetlinkPorts(), - protocol: protocol, - skType: skType, - ep: ep, - connection: connection, - sendBufferSize: defaultSendBufferSize, + socketOpsCommon: socketOpsCommon{ + ports: t.Kernel().NetlinkPorts(), + protocol: protocol, + skType: skType, + ep: ep, + connection: connection, + sendBufferSize: defaultSendBufferSize, + }, }, nil } // Release implements fs.FileOperations.Release. -func (s *Socket) Release() { - s.connection.Release() - s.ep.Close() +func (s *socketOpsCommon) Release(ctx context.Context) { + s.connection.Release(ctx) + s.ep.Close(ctx) if s.bound { s.ports.Release(s.protocol.Protocol(), s.portID) @@ -148,7 +174,7 @@ func (s *Socket) Release() { } // Readiness implements waiter.Waitable.Readiness. -func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { // ep holds messages to be read and thus handles EventIn readiness. ready := s.ep.Readiness(mask) @@ -162,16 +188,32 @@ func (s *Socket) Readiness(mask waiter.EventMask) waiter.EventMask { } // EventRegister implements waiter.Waitable.EventRegister. -func (s *Socket) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.ep.EventRegister(e, mask) // Writable readiness never changes, so no registration is needed. } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *Socket) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } +// Passcred implements transport.Credentialer.Passcred. +func (s *socketOpsCommon) Passcred() bool { + s.mu.Lock() + passcred := s.passcred + s.mu.Unlock() + return passcred +} + +// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. +func (s *socketOpsCommon) ConnectedPasscred() bool { + // This socket is connected to the kernel, which doesn't need creds. + // + // This is arbitrary, as ConnectedPasscred on this type has no callers. + return false +} + // Ioctl implements fs.FileOperations.Ioctl. func (*Socket) Ioctl(context.Context, *fs.File, usermem.IO, arch.SyscallArguments) (uintptr, error) { // TODO(b/68878065): no ioctls supported. @@ -199,7 +241,7 @@ func ExtractSockAddr(b []byte) (*linux.SockAddrNetlink, *syserr.Error) { // port of 0 defaults to the ThreadGroup ID. // // Preconditions: mu is held. -func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error { +func (s *socketOpsCommon) bindPort(t *kernel.Task, port int32) *syserr.Error { if s.bound { // Re-binding is only allowed if the port doesn't change. if port != s.portID { @@ -223,7 +265,7 @@ func (s *Socket) bindPort(t *kernel.Task, port int32) *syserr.Error { } // Bind implements socket.Socket.Bind. -func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { a, err := ExtractSockAddr(sockaddr) if err != nil { return err @@ -241,7 +283,7 @@ func (s *Socket) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { } // Connect implements socket.Socket.Connect. -func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { a, err := ExtractSockAddr(sockaddr) if err != nil { return err @@ -272,25 +314,25 @@ func (s *Socket) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr } // Accept implements socket.Socket.Accept. -func (s *Socket) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { // Netlink sockets never support accept. return 0, nil, 0, syserr.ErrNotSupported } // Listen implements socket.Socket.Listen. -func (s *Socket) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { // Netlink sockets never support listen. return syserr.ErrNotSupported } // Shutdown implements socket.Socket.Shutdown. -func (s *Socket) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { // Netlink sockets never support shutdown. return syserr.ErrNotSupported } // GetSockOpt implements socket.Socket.GetSockOpt. -func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *socketOpsCommon) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: switch name { @@ -300,18 +342,31 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. } s.mu.Lock() defer s.mu.Unlock() - return int32(s.sendBufferSize), nil + sendBufferSizeP := primitive.Int32(s.sendBufferSize) + return &sendBufferSizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } // We don't have limit on receiving size. - return int32(math.MaxInt32), nil + recvBufferSizeP := primitive.Int32(math.MaxInt32) + return &recvBufferSizeP, nil + + case linux.SO_PASSCRED: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + var passcred primitive.Int32 + if s.Passcred() { + passcred = 1 + } + return &passcred, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) } + case linux.SOL_NETLINK: switch name { case linux.NETLINK_BROADCAST_ERROR, @@ -330,7 +385,7 @@ func (s *Socket) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem. } // SetSockOpt implements socket.Socket.SetSockOpt. -func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { +func (s *socketOpsCommon) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { switch level { case linux.SOL_SOCKET: switch name { @@ -348,6 +403,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy s.sendBufferSize = size s.mu.Unlock() return nil + case linux.SO_RCVBUF: if len(opt) < sizeOfInt32 { return syserr.ErrInvalidArgument @@ -355,6 +411,52 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy // We don't have limit on receiving size. So just accept anything as // valid for compatibility. return nil + + case linux.SO_PASSCRED: + if len(opt) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + passcred := usermem.ByteOrder.Uint32(opt) + + s.mu.Lock() + s.passcred = passcred != 0 + s.mu.Unlock() + return nil + + case linux.SO_ATTACH_FILTER: + // TODO(gvisor.dev/issue/1119): We don't actually + // support filtering. If this socket can't ever send + // messages, then there is nothing to filter and we can + // advertise support. Otherwise, be conservative and + // return an error. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + s.filter = true + s.mu.Unlock() + return nil + + case linux.SO_DETACH_FILTER: + // TODO(gvisor.dev/issue/1119): See above. + if s.protocol.CanSend() { + socket.SetSockOptEmitUnimplementedEvent(t, name) + return syserr.ErrProtocolNotAvailable + } + + s.mu.Lock() + filter := s.filter + s.filter = false + s.mu.Unlock() + + if !filter { + return errNoFilter + } + + return nil + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -380,7 +482,7 @@ func (s *Socket) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *sy } // GetSockName implements socket.Socket.GetSockName. -func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { s.mu.Lock() defer s.mu.Unlock() @@ -392,7 +494,7 @@ func (s *Socket) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er } // GetPeerName implements socket.Socket.GetPeerName. -func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { sa := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, // TODO(b/68878065): Support non-kernel peers. For now the peer @@ -403,7 +505,7 @@ func (s *Socket) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Er } // RecvMsg implements socket.Socket.RecvMsg. -func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { from := &linux.SockAddrNetlink{ Family: linux.AF_NETLINK, PortID: 0, @@ -413,29 +515,29 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have trunc := flags&linux.MSG_TRUNC != 0 r := unix.EndpointReader{ + Ctx: t, Endpoint: s.ep, Peek: flags&linux.MSG_PEEK != 0, } + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + // If MSG_TRUNC is set with a zero byte destination then we still need // to read the message and discard it, or in the case where MSG_PEEK is // set, leave it be. In both cases the full message length must be - // returned. However, the memory manager for the destination will not read - // the endpoint if the destination is zero length. - // - // In order for the endpoint to be read when the destination size is zero, - // we must cause a read of the endpoint by using a separate fake zero - // length block sequence and calling the EndpointReader directly. + // returned. if trunc && dst.Addrs.NumBytes() == 0 { - // Perform a read to a zero byte block sequence. We can ignore the - // original destination since it was zero bytes. The length returned by - // ReadToBlocks is ignored and we return the full message length to comply - // with MSG_TRUNC. - _, err := r.ReadToBlocks(safemem.BlockSeqOf(safemem.BlockFromSafeSlice(make([]byte, 0)))) - return int(r.MsgSize), linux.MSG_TRUNC, from, fromLen, socket.ControlMessages{}, syserr.FromError(err) + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } } - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { + if n, err := doRead(); err != syserror.ErrWouldBlock || flags&linux.MSG_DONTWAIT != 0 { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC @@ -453,7 +555,7 @@ func (s *Socket) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, have defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var mflags int if n < int64(r.MsgSize) { mflags |= linux.MSG_TRUNC @@ -483,18 +585,43 @@ func (s *Socket) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ }) } +// kernelSCM implements control.SCMCredentials with credentials that represent +// the kernel itself rather than a Task. +// +// +stateify savable +type kernelSCM struct{} + +// Equals implements transport.CredentialsControlMessage.Equals. +func (kernelSCM) Equals(oc transport.CredentialsControlMessage) bool { + _, ok := oc.(kernelSCM) + return ok +} + +// Credentials implements control.SCMCredentials.Credentials. +func (kernelSCM) Credentials(*kernel.Task) (kernel.ThreadID, auth.UID, auth.GID) { + return 0, auth.RootUID, auth.RootGID +} + +// kernelCreds is the concrete version of kernelSCM used in all creds. +var kernelCreds = &kernelSCM{} + // sendResponse sends the response messages in ms back to userspace. -func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { +func (s *socketOpsCommon) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error { // Linux combines multiple netlink messages into a single datagram. bufs := make([][]byte, 0, len(ms.Messages)) for _, m := range ms.Messages { bufs = append(bufs, m.Finalize()) } + // All messages are from the kernel. + cms := transport.ControlMessages{ + Credentials: kernelCreds, + } + if len(bufs) > 0 { // RecvMsg never receives the address, so we don't need to send // one. - _, notify, err := s.connection.Send(bufs, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, bufs, cms, tcpip.FullAddress{}) // If the buffer is full, we simply drop messages, just like // Linux. if err != nil && err != syserr.ErrWouldBlock { @@ -521,7 +648,7 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error // Add the dump_done_errno payload. m.Put(int64(0)) - _, notify, err := s.connection.Send([][]byte{m.Finalize()}, transport.ControlMessages{}, tcpip.FullAddress{}) + _, notify, err := s.connection.Send(ctx, [][]byte{m.Finalize()}, cms, tcpip.FullAddress{}) if err != nil && err != syserr.ErrWouldBlock { return err } @@ -533,47 +660,38 @@ func (s *Socket) sendResponse(ctx context.Context, ms *MessageSet) *syserr.Error return nil } -func (s *Socket) dumpErrorMesage(ctx context.Context, hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) *syserr.Error { +func dumpErrorMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet, err *syserr.Error) { m := ms.AddMessage(linux.NetlinkMessageHeader{ Type: linux.NLMSG_ERROR, }) - m.Put(linux.NetlinkErrorMessage{ Error: int32(-err.ToLinux().Number()), Header: hdr, }) - return nil +} +func dumpAckMesage(hdr linux.NetlinkMessageHeader, ms *MessageSet) { + m := ms.AddMessage(linux.NetlinkMessageHeader{ + Type: linux.NLMSG_ERROR, + }) + m.Put(linux.NetlinkErrorMessage{ + Error: 0, + Header: hdr, + }) } // processMessages handles each message in buf, passing it to the protocol // handler for final handling. -func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error { +func (s *socketOpsCommon) processMessages(ctx context.Context, buf []byte) *syserr.Error { for len(buf) > 0 { - if len(buf) < linux.NetlinkMessageHeaderSize { + msg, rest, ok := ParseMessage(buf) + if !ok { // Linux ignores messages that are too short. See // net/netlink/af_netlink.c:netlink_rcv_skb. break } - - var hdr linux.NetlinkMessageHeader - binary.Unmarshal(buf[:linux.NetlinkMessageHeaderSize], usermem.ByteOrder, &hdr) - - if hdr.Length < linux.NetlinkMessageHeaderSize || uint64(hdr.Length) > uint64(len(buf)) { - // Linux ignores malformed messages. See - // net/netlink/af_netlink.c:netlink_rcv_skb. - break - } - - // Data from this message. - data := buf[linux.NetlinkMessageHeaderSize:hdr.Length] - - // Advance to the next message. - next := alignUp(int(hdr.Length), linux.NLMSG_ALIGNTO) - if next >= len(buf)-1 { - next = len(buf) - 1 - } - buf = buf[next:] + buf = rest + hdr := msg.Header() // Ignore control messages. if hdr.Type < linux.NLMSG_MIN_TYPE { @@ -581,19 +699,10 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error } ms := NewMessageSet(s.portID, hdr.Seq) - var err *syserr.Error - // TODO(b/68877377): ACKs not supported yet. - if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { - err = syserr.ErrNotSupported - } else { - - err = s.protocol.ProcessMessage(ctx, hdr, data, ms) - } - if err != nil { - ms = NewMessageSet(s.portID, hdr.Seq) - if err := s.dumpErrorMesage(ctx, hdr, ms, err); err != nil { - return err - } + if err := s.protocol.ProcessMessage(ctx, msg, ms); err != nil { + dumpErrorMesage(hdr, ms, err) + } else if hdr.Flags&linux.NLM_F_ACK == linux.NLM_F_ACK { + dumpAckMesage(hdr, ms) } if err := s.sendResponse(ctx, ms); err != nil { @@ -605,7 +714,7 @@ func (s *Socket) processMessages(ctx context.Context, buf []byte) *syserr.Error } // sendMsg is the core of message send, used for SendMsg and Write. -func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, flags int, controlMessages socket.ControlMessages) (int, *syserr.Error) { dstPort := int32(0) if len(to) != 0 { @@ -652,7 +761,7 @@ func (s *Socket) sendMsg(ctx context.Context, src usermem.IOSequence, to []byte, } // SendMsg implements socket.Socket.SendMsg. -func (s *Socket) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { return s.sendMsg(t, src, to, flags, controlMessages) } @@ -663,11 +772,13 @@ func (s *Socket) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, } // State implements socket.Socket.State. -func (s *Socket) State() uint32 { +func (s *socketOpsCommon) State() uint32 { return s.ep.State() } // Type implements socket.Socket.Type. -func (s *Socket) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return linux.AF_NETLINK, s.skType, s.protocol.Protocol() } + +// LINT.ThenChange(./socket_vfs2.go) diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go new file mode 100644 index 000000000..a38d25da9 --- /dev/null +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -0,0 +1,152 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netlink + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/unix" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" +) + +// SocketVFS2 is the base VFS2 socket type for netlink sockets. +// +// This implementation only supports userspace sending and receiving messages +// to/from the kernel. +// +// SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD + + socketOpsCommon +} + +var _ socket.SocketVFS2 = (*SocketVFS2)(nil) +var _ transport.Credentialer = (*SocketVFS2)(nil) + +// NewVFS2 creates a new SocketVFS2. +func NewVFS2(t *kernel.Task, skType linux.SockType, protocol Protocol) (*SocketVFS2, *syserr.Error) { + // Datagram endpoint used to buffer kernel -> user messages. + ep := transport.NewConnectionless(t) + + // Bind the endpoint for good measure so we can connect to it. The + // bound address will never be exposed. + if err := ep.Bind(tcpip.FullAddress{Addr: "dummy"}, nil); err != nil { + ep.Close(t) + return nil, err + } + + // Create a connection from which the kernel can write messages. + connection, err := ep.(transport.BoundEndpoint).UnidirectionalConnect(t) + if err != nil { + ep.Close(t) + return nil, err + } + + fd := &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + ports: t.Kernel().NetlinkPorts(), + protocol: protocol, + skType: skType, + ep: ep, + connection: connection, + sendBufferSize: defaultSendBufferSize, + }, + } + fd.LockFD.Init(&vfs.FileLocks{}) + return fd, nil +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (*SocketVFS2) Ioctl(context.Context, usermem.IO, arch.SyscallArguments) (uintptr, error) { + // TODO(b/68878065): no ioctls supported. + return 0, syserror.ENOTTY +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + return dst.CopyOutFrom(ctx, &unix.EndpointReader{ + Endpoint: s.ep, + }) +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + n, err := s.sendMsg(ctx, src, nil, 0, socket.ControlMessages{}) + return int64(n), err.ToError() +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/netlink/uevent/BUILD b/pkg/sentry/socket/netlink/uevent/BUILD new file mode 100644 index 000000000..b6434923c --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "uevent", + srcs = ["protocol.go"], + visibility = ["//pkg/sentry:internal"], + deps = [ + "//pkg/abi/linux", + "//pkg/context", + "//pkg/sentry/kernel", + "//pkg/sentry/socket/netlink", + "//pkg/syserr", + ], +) diff --git a/pkg/sentry/socket/netlink/uevent/protocol.go b/pkg/sentry/socket/netlink/uevent/protocol.go new file mode 100644 index 000000000..029ba21b5 --- /dev/null +++ b/pkg/sentry/socket/netlink/uevent/protocol.go @@ -0,0 +1,60 @@ +// Copyright 2019 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package uevent provides a NETLINK_KOBJECT_UEVENT socket protocol. +// +// NETLINK_KOBJECT_UEVENT sockets send udev-style device events. gVisor does +// not support any device events, so these sockets never send any messages. +package uevent + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket/netlink" + "gvisor.dev/gvisor/pkg/syserr" +) + +// Protocol implements netlink.Protocol. +// +// +stateify savable +type Protocol struct{} + +var _ netlink.Protocol = (*Protocol)(nil) + +// NewProtocol creates a NETLINK_KOBJECT_UEVENT netlink.Protocol. +func NewProtocol(t *kernel.Task) (netlink.Protocol, *syserr.Error) { + return &Protocol{}, nil +} + +// Protocol implements netlink.Protocol.Protocol. +func (p *Protocol) Protocol() int { + return linux.NETLINK_KOBJECT_UEVENT +} + +// CanSend implements netlink.Protocol.CanSend. +func (p *Protocol) CanSend() bool { + return false +} + +// ProcessMessage implements netlink.Protocol.ProcessMessage. +func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + // Silently ignore all messages. + return nil +} + +// init registers the NETLINK_KOBJECT_UEVENT provider. +func init() { + netlink.RegisterProvider(linux.NETLINK_KOBJECT_UEVENT, NewProtocol) +} diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD index e414d8055..1fb777a6c 100644 --- a/pkg/sentry/socket/netstack/BUILD +++ b/pkg/sentry/socket/netstack/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") package(licenses = ["notice"]) @@ -7,44 +7,52 @@ go_library( srcs = [ "device.go", "netstack.go", + "netstack_vfs2.go", "provider.go", + "provider_vfs2.go", "save_restore.go", "stack.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/netstack", visibility = [ "//pkg/sentry:internal", ], deps = [ "//pkg/abi/linux", + "//pkg/amutex", "//pkg/binary", + "//pkg/context", "//pkg/log", "//pkg/metric", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/netfilter", "//pkg/sentry/unimpl", - "//pkg/sentry/usermem", + "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/header", - "//pkg/tcpip/iptables", "//pkg/tcpip/network/ipv4", "//pkg/tcpip/network/ipv6", "//pkg/tcpip/stack", "//pkg/tcpip/transport/tcp", "//pkg/tcpip/transport/udp", + "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", + "//tools/go_marshal/primitive", + "@org_golang_x_sys//unix:go_default_library", ], ) diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 27c6692c4..e4846bc0b 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -26,29 +26,32 @@ package netstack import ( "bytes" + "fmt" "io" "math" "reflect" - "sync" + "sync/atomic" "syscall" "time" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/metric" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/safemem" "gvisor.dev/gvisor/pkg/sentry/socket" "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" @@ -57,12 +60,21 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" "gvisor.dev/gvisor/pkg/tcpip/transport/tcp" "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" ) func mustCreateMetric(name, description string) *tcpip.StatCounter { var cm tcpip.StatCounter - metric.MustRegisterCustomUint64Metric(name, false /* sync */, description, cm.Value) + metric.MustRegisterCustomUint64Metric(name, true /* cumulative */, false /* sync */, description, cm.Value) + return &cm +} + +func mustCreateGauge(name, description string) *tcpip.StatCounter { + var cm tcpip.StatCounter + metric.MustRegisterCustomUint64Metric(name, false /* cumulative */, false /* sync */, description, cm.Value) return &cm } @@ -138,19 +150,23 @@ var Metrics = tcpip.Stats{ }, }, IP: tcpip.IPStats{ - PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), - InvalidAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), - PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), - PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), - OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), - MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), - MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), + PacketsReceived: mustCreateMetric("/netstack/ip/packets_received", "Total number of IP packets received from the link layer in nic.DeliverNetworkPacket."), + InvalidDestinationAddressesReceived: mustCreateMetric("/netstack/ip/invalid_addresses_received", "Total number of IP packets received with an unknown or invalid destination address."), + InvalidSourceAddressesReceived: mustCreateMetric("/netstack/ip/invalid_source_addresses_received", "Total number of IP packets received with an unknown or invalid source address."), + PacketsDelivered: mustCreateMetric("/netstack/ip/packets_delivered", "Total number of incoming IP packets that are successfully delivered to the transport layer via HandlePacket."), + PacketsSent: mustCreateMetric("/netstack/ip/packets_sent", "Total number of IP packets sent via WritePacket."), + OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."), + MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."), + MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."), }, TCP: tcpip.TCPStats{ ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."), PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."), - CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in either ESTABLISHED or CLOSE-WAIT state now."), + CurrentEstablished: mustCreateGauge("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."), + CurrentConnected: mustCreateGauge("/netstack/tcp/current_open", "Number of connections that are in connected state."), EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"), + EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to CLOSED state."), + EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."), ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."), ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."), ListenOverflowSynCookieSent: mustCreateMetric("/netstack/tcp/listen_overflow_syn_cookie_sent", "Number of times a SYN cookie was sent."), @@ -178,6 +194,8 @@ var Metrics = tcpip.Stats{ MalformedPacketsReceived: mustCreateMetric("/netstack/udp/malformed_packets_received", "Number of incoming UDP datagrams dropped due to the UDP header being in a malformed state."), PacketsSent: mustCreateMetric("/netstack/udp/packets_sent", "Number of UDP datagrams sent."), PacketSendErrors: mustCreateMetric("/netstack/udp/packet_send_errors", "Number of UDP datagrams failed to be sent."), + ChecksumErrors: mustCreateMetric("/netstack/udp/checksum_errors", "Number of UDP datagrams dropped due to bad checksums."), + InvalidSourceAddress: mustCreateMetric("/netstack/udp/invalid_source", "Number of UDP datagrams dropped due to invalid source address."), }, } @@ -220,19 +238,29 @@ type commonEndpoint interface { // transport.Endpoint.SetSockOpt. SetSockOpt(interface{}) *tcpip.Error + // SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and + // transport.Endpoint.SetSockOptBool. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt implements tcpip.Endpoint.SetSockOptInt and // transport.Endpoint.SetSockOptInt. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt implements tcpip.Endpoint.GetSockOpt and // transport.Endpoint.GetSockOpt. GetSockOpt(interface{}) *tcpip.Error + // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and + // transport.Endpoint.GetSockOpt. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) } +// LINT.IfChange + // SocketOperations encapsulates all the state needed to represent a network stack // endpoint in the kernel context. // @@ -244,6 +272,14 @@ type SocketOperations struct { fsutil.FileNoFsync `state:"nosave"` fsutil.FileNoMMap `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` + + socketOpsCommon +} + +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { socket.SendReceiveTimeout *waiter.Queue @@ -252,14 +288,21 @@ type SocketOperations struct { skType linux.SockType protocol int + // readViewHasData is 1 iff readView has data to be read, 0 otherwise. + // Must be accessed using atomic operations. It must only be written + // with readMu held but can be read without holding readMu. The latter + // is required to avoid deadlocks in epoll Readiness checks. + readViewHasData uint32 + // readMu protects access to the below fields. readMu sync.Mutex `state:"nosave"` // readView contains the remaining payload from the last packet. readView buffer.View // readCM holds control message information for the last packet read // from Endpoint. - readCM tcpip.ControlMessages - sender tcpip.FullAddress + readCM tcpip.ControlMessages + sender tcpip.FullAddress + linkPacketInfo tcpip.LinkPacketInfo // sockOptTimestamp corresponds to SO_TIMESTAMP. When true, timestamps // of returned messages can be returned via control messages. When @@ -281,19 +324,21 @@ type SocketOperations struct { // New creates a new endpoint socket. func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) { if skType == linux.SOCK_STREAM { - if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { return nil, syserr.TranslateNetstackError(err) } } dirent := socket.NewDirent(t, netstackDevice) - defer dirent.DecRef() + defer dirent.DecRef(t) return fs.NewFile(t, dirent, fs.FileFlags{Read: true, Write: true, NonSeekable: true}, &SocketOperations{ - Queue: queue, - family: family, - Endpoint: endpoint, - skType: skType, - protocol: protocol, + socketOpsCommon: socketOpsCommon{ + Queue: queue, + family: family, + Endpoint: endpoint, + skType: skType, + protocol: protocol, + }, }), nil } @@ -314,22 +359,15 @@ func bytesToIPAddress(addr []byte) tcpip.Address { // converts it to the FullAddress format. It supports AF_UNIX, AF_INET, // AF_INET6, and AF_PACKET addresses. // -// strict indicates whether addresses with the AF_UNSPEC family are accepted of not. -// // AddressAndFamily returns an address and its family. -func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, uint16, *syserr.Error) { +func AddressAndFamily(addr []byte) (tcpip.FullAddress, uint16, *syserr.Error) { // Make sure we have at least 2 bytes for the address family. if len(addr) < 2 { return tcpip.FullAddress{}, 0, syserr.ErrInvalidArgument } - family := usermem.ByteOrder.Uint16(addr) - if family != uint16(sfamily) && (!strict && family != linux.AF_UNSPEC) { - return tcpip.FullAddress{}, family, syserr.ErrAddressFamilyNotSupported - } - // Get the rest of the fields based on the address family. - switch family { + switch family := usermem.ByteOrder.Uint16(addr); family { case linux.AF_UNIX: path := addr[2:] if len(path) > linux.UnixPathMax { @@ -385,7 +423,7 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, return tcpip.FullAddress{}, family, syserr.ErrInvalidArgument } - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. return tcpip.FullAddress{ NIC: tcpip.NICID(a.InterfaceIndex), Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), @@ -399,33 +437,49 @@ func AddressAndFamily(sfamily int, addr []byte, strict bool) (tcpip.FullAddress, } } -func (s *SocketOperations) isPacketBased() bool { +func (s *socketOpsCommon) isPacketBased() bool { return s.skType == linux.SOCK_DGRAM || s.skType == linux.SOCK_SEQPACKET || s.skType == linux.SOCK_RDM || s.skType == linux.SOCK_RAW } // fetchReadView updates the readView field of the socket if it's currently // empty. It assumes that the socket is locked. -func (s *SocketOperations) fetchReadView() *syserr.Error { +// +// Precondition: s.readMu must be held. +func (s *socketOpsCommon) fetchReadView() *syserr.Error { if len(s.readView) > 0 { return nil } - s.readView = nil s.sender = tcpip.FullAddress{} + s.linkPacketInfo = tcpip.LinkPacketInfo{} - v, cms, err := s.Endpoint.Read(&s.sender) + var v buffer.View + var cms tcpip.ControlMessages + var err *tcpip.Error + + switch e := s.Endpoint.(type) { + // The ordering of these interfaces matters. The most specific + // interfaces must be specified before the more generic Endpoint + // interface. + case tcpip.PacketEndpoint: + v, cms, err = e.ReadPacket(&s.sender, &s.linkPacketInfo) + case tcpip.Endpoint: + v, cms, err = e.Read(&s.sender) + } if err != nil { + atomic.StoreUint32(&s.readViewHasData, 0) return syserr.TranslateNetstackError(err) } s.readView = v s.readCM = cms + atomic.StoreUint32(&s.readViewHasData, 1) return nil } // Release implements fs.FileOperations.Release. -func (s *SocketOperations) Release() { +func (s *socketOpsCommon) Release(context.Context) { s.Endpoint.Close() } @@ -520,11 +574,9 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO } if resCh != nil { - t := ctx.(*kernel.Task) - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err).ToError() + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) } @@ -593,11 +645,9 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } if resCh != nil { - t := ctx.(*kernel.Task) - if err := t.Block(resCh); err != nil { - return 0, syserr.FromError(err).ToError() + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err } - n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{ Atomic: true, // See above. }) @@ -612,26 +662,54 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader } // Readiness returns a mask of ready events for socket s. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { r := s.Endpoint.Readiness(mask) // Check our cached value iff the caller asked for readability and the // endpoint itself is currently not readable. if (mask & ^r & waiter.EventIn) != 0 { - s.readMu.Lock() - if len(s.readView) > 0 { + if atomic.LoadUint32(&s.readViewHasData) == 1 { r |= waiter.EventIn } - s.readMu.Unlock() } return r } +func (s *socketOpsCommon) checkFamily(family uint16, exact bool) *syserr.Error { + if family == uint16(s.family) { + return nil + } + if !exact && family == linux.AF_INET && s.family == linux.AF_INET6 { + v, err := s.Endpoint.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { + return syserr.TranslateNetstackError(err) + } + if !v { + return nil + } + } + return syserr.ErrInvalidArgument +} + +// mapFamily maps the AF_INET ANY address to the IPv4-mapped IPv6 ANY if the +// receiver's family is AF_INET6. +// +// This is a hack to work around the fact that both IPv4 and IPv6 ANY are +// represented by the empty string. +// +// TODO(gvisor.dev/issue/1556): remove this function. +func (s *socketOpsCommon) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress { + if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET { + addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00" + } + return addr +} + // Connect implements the linux syscall connect(2) for sockets backed by // tpcip.Endpoint. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - addr, family, err := AddressAndFamily(s.family, sockaddr, false /* strict */) +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { + addr, family, err := AddressAndFamily(sockaddr) if err != nil { return err } @@ -643,6 +721,12 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo } return syserr.TranslateNetstackError(err) } + + if err := s.checkFamily(family, false /* exact */); err != nil { + return err + } + addr = s.mapFamily(addr, family) + // Always return right away in the non-blocking case. if !blocking { return syserr.TranslateNetstackError(s.Endpoint.Connect(addr)) @@ -655,6 +739,14 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo defer s.EventUnregister(&e) if err := s.Endpoint.Connect(addr); err != tcpip.ErrConnectStarted && err != tcpip.ErrAlreadyConnecting { + if (s.family == unix.AF_INET || s.family == unix.AF_INET6) && s.skType == linux.SOCK_STREAM { + // TCP unlike UDP returns EADDRNOTAVAIL when it can't + // find an available local ephemeral port. + if err == tcpip.ErrNoPortAvailable { + return syserr.ErrAddressNotAvailable + } + } + return syserr.TranslateNetstackError(err) } @@ -670,10 +762,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo // Bind implements the linux syscall bind(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - addr, _, err := AddressAndFamily(s.family, sockaddr, true /* strict */) - if err != nil { - return err +func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + if len(sockaddr) < 2 { + return syserr.ErrInvalidArgument + } + + family := usermem.ByteOrder.Uint16(sockaddr) + var addr tcpip.FullAddress + + // Bind for AF_PACKET requires only family, protocol and ifindex. + // In function AddressAndFamily, we check the address length which is + // not needed for AF_PACKET bind. + if family == linux.AF_PACKET { + var a linux.SockAddrLink + if len(sockaddr) < sockAddrLinkSize { + return syserr.ErrInvalidArgument + } + binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a) + + if a.Protocol != uint16(s.protocol) { + return syserr.ErrInvalidArgument + } + + addr = tcpip.FullAddress{ + NIC: tcpip.NICID(a.InterfaceIndex), + Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]), + } + } else { + var err *syserr.Error + addr, family, err = AddressAndFamily(sockaddr) + if err != nil { + return err + } + + if err = s.checkFamily(family, true /* exact */); err != nil { + return err + } + + addr = s.mapFamily(addr, family) } // Issue the bind request to the endpoint. @@ -682,13 +808,13 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // Listen implements the linux syscall listen(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return syserr.TranslateNetstackError(s.Endpoint.Listen(backlog)) } // blockingAccept implements a blocking version of accept(2), that is, if no // connections are ready to be accept, it will block until one becomes ready. -func (s *SocketOperations) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { +func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) { // Register for notifications. e, ch := waiter.NewChannelEntry(nil) s.EventRegister(&e, waiter.EventIn) @@ -728,7 +854,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, if err != nil { return 0, nil, 0, err } - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -774,7 +900,7 @@ func ConvertShutdown(how int) (tcpip.ShutdownFlags, *syserr.Error) { // Shutdown implements the linux syscall shutdown(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := ConvertShutdown(how) if err != nil { return err @@ -786,7 +912,7 @@ func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is // implemented specifically for netstack.SocketOperations rather than // commonEndpoint. commonEndpoint should be extended to support socket @@ -796,25 +922,25 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptTimestamp { val = 1 } - return val, nil + return &val, nil } if level == linux.SOL_TCP && name == linux.TCP_INQ { if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - val := int32(0) + val := primitive.Int32(0) s.readMu.Lock() defer s.readMu.Unlock() if s.sockOptInq { val = 1 } - return val, nil + return &val, nil } if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { @@ -824,22 +950,30 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us return nil, syserr.ErrInvalidArgument } - info, err := netfilter.GetInfo(t, s.Endpoint, outPtr) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) if err != nil { return nil, err } - return info, nil + return &info, nil case linux.IPT_SO_GET_ENTRIES: if outLen < linux.SizeOfIPTGetEntries { return nil, syserr.ErrInvalidArgument } - entries, err := netfilter.GetEntries(t, s.Endpoint, outPtr, outLen) + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) if err != nil { return nil, err } - return entries, nil + return &entries, nil } } @@ -849,7 +983,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us // GetSockOpt can be used to implement the linux syscall getsockopt(2) for // sockets backed by a commonEndpoint. -func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) { +func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch level { case linux.SOL_SOCKET: return getSockOptSocket(t, s, ep, family, skType, name, outLen) @@ -874,8 +1008,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, return nil, syserr.ErrProtocolNotAvailable } +func boolToInt32(v bool) int32 { + if v { + return 1 + } + return 0 +} + // getSockOptSocket implements GetSockOpt when level is SOL_SOCKET. -func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (marshal.Marshallable, *syserr.Error) { // TODO(b/124056281): Stop rejecting short optLen values in getsockopt. switch name { case linux.SO_ERROR: @@ -886,9 +1027,12 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family // Get the last error and convert it. err := ep.GetSockOpt(tcpip.ErrorOption{}) if err == nil { - return int32(0), nil + optP := primitive.Int32(0) + return &optP, nil } - return int32(syserr.TranslateNetstackError(err).ToLinux().Number()), nil + + optP := primitive.Int32(syserr.TranslateNetstackError(err).ToLinux().Number()) + return &optP, nil case linux.SO_PEERCRED: if family != linux.AF_UNIX || outLen < syscall.SizeofUcred { @@ -896,23 +1040,25 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family } tcred := t.Credentials() - return syscall.Ucred{ - Pid: int32(t.ThreadGroup().ID()), - Uid: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), - Gid: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), - }, nil + creds := linux.ControlMessageCredentials{ + PID: int32(t.ThreadGroup().ID()), + UID: uint32(tcred.EffectiveKUID.In(tcred.UserNamespace).OrOverflow()), + GID: uint32(tcred.EffectiveKGID.In(tcred.UserNamespace).OrOverflow()), + } + return &creds, nil case linux.SO_PASSCRED: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.PasscredOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.PasscredOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_SNDBUF: if outLen < sizeOfInt32 { @@ -928,7 +1074,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_RCVBUF: if outLen < sizeOfInt32 { @@ -944,74 +1091,93 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family size = math.MaxInt32 } - return int32(size), nil + sizeP := primitive.Int32(size) + return &sizeP, nil case linux.SO_REUSEADDR: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReuseAddressOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_REUSEPORT: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.ReusePortOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.ReusePortOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_BINDTODEVICE: var v tcpip.BindToDeviceOption if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - if len(v) == 0 { - return []byte{}, nil + if v == 0 { + var b primitive.ByteSlice + return &b, nil } if outLen < linux.IFNAMSIZ { return nil, syserr.ErrInvalidArgument } - return append([]byte(v), 0), nil + s := t.NetworkContext() + if s == nil { + return nil, syserr.ErrNoDevice + } + nic, ok := s.Interfaces()[int32(v)] + if !ok { + // The NICID no longer indicates a valid interface, probably because that + // interface was removed. + return nil, syserr.ErrUnknownDevice + } + + name := primitive.ByteSlice(append([]byte(nic.Name), 0)) + return &name, nil case linux.SO_BROADCAST: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.BroadcastOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.BroadcastOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_KEEPALIVE: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.KeepaliveEnabledOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.SO_LINGER: if outLen < linux.SizeOfLinger { return nil, syserr.ErrInvalidArgument } - return linux.Linger{}, nil + + linger := linux.Linger{} + return &linger, nil case linux.SO_SNDTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1019,7 +1185,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.SendTimeout()), nil + sendTimeout := linux.NsecToTimeval(s.SendTimeout()) + return &sendTimeout, nil case linux.SO_RCVTIMEO: // TODO(igudger): Linux allows shorter lengths for partial results. @@ -1027,7 +1194,8 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.ErrInvalidArgument } - return linux.NsecToTimeval(s.RecvTimeout()), nil + recvTimeout := linux.NsecToTimeval(s.RecvTimeout()) + return &recvTimeout, nil case linux.SO_OOBINLINE: if outLen < sizeOfInt32 { @@ -1039,7 +1207,20 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil + + case linux.SO_NO_CHECK: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.NoChecksumOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil default: socket.GetSockOptEmitUnimplementedEvent(t, name) @@ -1048,58 +1229,58 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family } // getSockOptTCP implements GetSockOpt when level is SOL_TCP. -func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.TCP_NODELAY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - v, err := ep.GetSockOptInt(tcpip.DelayOption) + v, err := ep.GetSockOptBool(tcpip.DelayOption) if err != nil { return nil, syserr.TranslateNetstackError(err) } - if v == 0 { - return int32(1), nil - } - return int32(0), nil + vP := primitive.Int32(boolToInt32(!v)) + return &vP, nil case linux.TCP_CORK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.CorkOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.CorkOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_QUICKACK: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.QuickAckOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.QuickAckOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.TCP_MAXSEG: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MaxSegOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MaxSegOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.TCP_KEEPIDLE: if outLen < sizeOfInt32 { @@ -1110,8 +1291,8 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } - - return int32(time.Duration(v) / time.Second), nil + keepAliveIdle := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveIdle, nil case linux.TCP_KEEPINTVL: if outLen < sizeOfInt32 { @@ -1122,8 +1303,32 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa if err := ep.GetSockOpt(&v); err != nil { return nil, syserr.TranslateNetstackError(err) } + keepAliveInterval := primitive.Int32(time.Duration(v) / time.Second) + return &keepAliveInterval, nil + + case linux.TCP_KEEPCNT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.KeepaliveCountOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(v) + return &vP, nil - return int32(time.Duration(v) / time.Second), nil + case linux.TCP_USER_TIMEOUT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPUserTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + tcpUserTimeout := primitive.Int32(time.Duration(v) / time.Millisecond) + return &tcpUserTimeout, nil case linux.TCP_INFO: var v tcpip.TCPInfoOption @@ -1136,12 +1341,13 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa info := linux.TCPInfo{} // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &info) - if len(ib) > outLen { - ib = ib[:outLen] + buf := t.CopyScratchBuffer(info.SizeBytes()) + info.MarshalUnsafe(buf) + if len(buf) > outLen { + buf = buf[:outLen] } - - return ib, nil + bufP := primitive.ByteSlice(buf) + return &bufP, nil case linux.TCP_CC_INFO, linux.TCP_NOTSENT_LOWAT, @@ -1171,8 +1377,59 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } b := make([]byte, toCopy) copy(b, v) - return b, nil + bP := primitive.ByteSlice(b) + return &bP, nil + + case linux.TCP_LINGER2: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPLingerTimeoutOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + lingerTimeout := primitive.Int32(time.Duration(v) / time.Second) + return &lingerTimeout, nil + + case linux.TCP_DEFER_ACCEPT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.TCPDeferAcceptOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + tcpDeferAccept := primitive.Int32(time.Duration(v) / time.Second) + return &tcpDeferAccept, nil + + case linux.TCP_SYNCNT: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.TCPSynCountOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(v) + return &vP, nil + + case linux.TCP_WINDOW_CLAMP: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptInt(tcpip.TCPWindowClampOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(v) + return &vP, nil default: emitUnimplementedEventTCP(t, name) } @@ -1180,19 +1437,20 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interface{}, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.V6OnlyOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.V6OnlyOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IPV6_PATHMTU: t.Kernel().EmitUnimplementedEvent(t) @@ -1200,21 +1458,41 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf case linux.IPV6_TCLASS: // Length handling for parity with Linux. if outLen == 0 { - return make([]byte, 0), nil + var b primitive.ByteSlice + return &b, nil } - var v tcpip.IPv6TrafficClassOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - uintv := uint32(v) + uintv := primitive.Uint32(v) // Linux truncates the output binary to outLen. - ib := binary.Marshal(nil, usermem.ByteOrder, &uintv) + ib := t.CopyScratchBuffer(uintv.SizeBytes()) + uintv.MarshalUnsafe(ib) // Handle cases where outLen is lesser than sizeOfInt32. if len(ib) > outLen { ib = ib[:outLen] } - return ib, nil + ibP := primitive.ByteSlice(ib) + return &ibP, nil + + case linux.IPV6_RECVTCLASS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.SO_ORIGINAL_DST: + // TODO(gvisor.dev/issue/170): ip6tables. + return nil, syserr.ErrInvalidArgument default: emitUnimplementedEventIPv6(t, name) @@ -1223,36 +1501,38 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf } // getSockOptIP implements GetSockOpt when level is SOL_IP. -func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (interface{}, *syserr.Error) { +func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IP_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.TTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.TTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } // Fill in the default value, if needed. - if v == 0 { - v = DefaultTTL + vP := primitive.Int32(v) + if vP == 0 { + vP = DefaultTTL } - return int32(v), nil + return &vP, nil case linux.IP_MULTICAST_TTL: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastTTLOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - return int32(v), nil + vP := primitive.Int32(v) + return &vP, nil case linux.IP_MULTICAST_IF: if outLen < len(linux.InetAddr{}) { @@ -1266,36 +1546,76 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress{Addr: v.InterfaceAddr}) - return a.(*linux.SockAddrInet).Addr, nil + return &a.(*linux.SockAddrInet).Addr, nil case linux.IP_MULTICAST_LOOP: if outLen < sizeOfInt32 { return nil, syserr.ErrInvalidArgument } - var v tcpip.MulticastLoopOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } - if v { - return int32(1), nil - } - return int32(0), nil + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil case linux.IP_TOS: // Length handling for parity with Linux. if outLen == 0 { - return []byte(nil), nil + var b primitive.ByteSlice + return &b, nil } - var v tcpip.IPv4TOSOption - if err := ep.GetSockOpt(&v); err != nil { + v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption) + if err != nil { return nil, syserr.TranslateNetstackError(err) } if outLen < sizeOfInt32 { - return uint8(v), nil + vP := primitive.Uint8(v) + return &vP, nil + } + vP := primitive.Int32(v) + return &vP, nil + + case linux.IP_RECVTOS: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument } - return int32(v), nil + + v, err := ep.GetSockOptBool(tcpip.ReceiveTOSOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.IP_PKTINFO: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + + case linux.SO_ORIGINAL_DST: + if outLen < int(binary.Size(linux.SockAddrInet{})) { + return nil, syserr.ErrInvalidArgument + } + + var v tcpip.OriginalDestinationOption + if err := ep.GetSockOpt(&v); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v)) + return a.(*linux.SockAddrInet), nil default: emitUnimplementedEventIP(t, name) @@ -1330,12 +1650,32 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa return nil } + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal) + + case linux.IPT_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + } + } + return SetSockOpt(t, s, s.Endpoint, level, name, optVal) } // SetSockOpt can be used to implement the linux syscall setsockopt(2) for // sockets backed by a commonEndpoint. -func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { +func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error { switch level { case linux.SOL_SOCKET: return setSockOptSocket(t, s, ep, name, optVal) @@ -1362,7 +1702,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n } // setSockOptSocket implements SetSockOpt when level is SOL_SOCKET. -func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.SO_SNDBUF: if len(optVal) < sizeOfInt32 { @@ -1386,7 +1726,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0)) case linux.SO_REUSEPORT: if len(optVal) < sizeOfInt32 { @@ -1394,14 +1734,27 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0)) case linux.SO_BINDTODEVICE: n := bytes.IndexByte(optVal, 0) if n == -1 { n = len(optVal) } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(optVal[:n]))) + name := string(optVal[:n]) + if name == "" { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0))) + } + s := t.NetworkContext() + if s == nil { + return syserr.ErrNoDevice + } + for nicID, nic := range s.Interfaces() { + if nic.Name == name { + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID))) + } + } + return syserr.ErrUnknownDevice case linux.SO_BROADCAST: if len(optVal) < sizeOfInt32 { @@ -1409,7 +1762,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0)) case linux.SO_PASSCRED: if len(optVal) < sizeOfInt32 { @@ -1417,7 +1770,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0)) case linux.SO_KEEPALIVE: if len(optVal) < sizeOfInt32 { @@ -1425,7 +1778,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0)) case linux.SO_SNDTIMEO: if len(optVal) < linux.SizeOfTimeval { @@ -1466,6 +1819,14 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v))) + case linux.SO_NO_CHECK: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.NoChecksumOption, v != 0)) + case linux.SO_LINGER: if len(optVal) < linux.SizeOfLinger { return syserr.ErrInvalidArgument @@ -1480,6 +1841,11 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i return nil + case linux.SO_DETACH_FILTER: + // optval is ignored. + var v tcpip.SocketDetachFilterOption + return syserr.TranslateNetstackError(ep.SetSockOpt(v)) + default: socket.SetSockOptEmitUnimplementedEvent(t, name) } @@ -1497,11 +1863,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - var o int - if v == 0 { - o = 1 - } - return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o)) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0)) case linux.TCP_CORK: if len(optVal) < sizeOfInt32 { @@ -1509,7 +1871,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0)) case linux.TCP_QUICKACK: if len(optVal) < sizeOfInt32 { @@ -1517,7 +1879,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0)) case linux.TCP_MAXSEG: if len(optVal) < sizeOfInt32 { @@ -1525,7 +1887,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v))) case linux.TCP_KEEPIDLE: if len(optVal) < sizeOfInt32 { @@ -1549,6 +1911,28 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v)))) + case linux.TCP_KEEPCNT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + if v < 1 || v > linux.MAX_TCP_KEEPCNT { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.KeepaliveCountOption, int(v))) + + case linux.TCP_USER_TIMEOUT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + return syserr.ErrInvalidArgument + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v)))) + case linux.TCP_CONGESTION: v := tcpip.CongestionControlOption(optVal) if err := ep.SetSockOpt(v); err != nil { @@ -1556,6 +1940,40 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } return nil + case linux.TCP_LINGER2: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + + v := usermem.ByteOrder.Uint32(optVal) + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v)))) + + case linux.TCP_DEFER_ACCEPT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := int32(usermem.ByteOrder.Uint32(optVal)) + if v < 0 { + v = 0 + } + return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v)))) + + case linux.TCP_SYNCNT: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := usermem.ByteOrder.Uint32(optVal) + + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPSynCountOption, int(v))) + + case linux.TCP_WINDOW_CLAMP: + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + v := usermem.ByteOrder.Uint32(optVal) + + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TCPWindowClampOption, int(v))) + case linux.TCP_REPAIR_OPTIONS: t.Kernel().EmitUnimplementedEvent(t) @@ -1576,13 +1994,14 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) } v := usermem.ByteOrder.Uint32(optVal) - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.V6OnlyOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.V6OnlyOption, v != 0)) case linux.IPV6_ADD_MEMBERSHIP, linux.IPV6_DROP_MEMBERSHIP, linux.IPV6_IPSEC_POLICY, linux.IPV6_JOIN_ANYCAST, linux.IPV6_LEAVE_ANYCAST, + // TODO(b/148887420): Add support for IPV6_PKTINFO. linux.IPV6_PKTINFO, linux.IPV6_ROUTER_ALERT, linux.IPV6_XFRM_POLICY, @@ -1606,7 +2025,15 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) if v == -1 { v = 0 } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v))) + + case linux.IPV6_RECVTCLASS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) default: emitUnimplementedEventIPv6(t, name) @@ -1683,7 +2110,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if v < 0 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v))) case linux.IP_ADD_MEMBERSHIP: req, err := copyInMulticastRequest(optVal, false /* allowAddr */) @@ -1730,9 +2157,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s return err } - return syserr.TranslateNetstackError(ep.SetSockOpt( - tcpip.MulticastLoopOption(v != 0), - )) + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0)) case linux.MCAST_JOIN_GROUP: // FIXME(b/124219304): Implement MCAST_JOIN_GROUP. @@ -1751,7 +2176,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s } else if v < 1 || v > 255 { return syserr.ErrInvalidArgument } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v))) case linux.IP_TOS: if len(optVal) == 0 { @@ -1761,7 +2186,34 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s if err != nil { return err } - return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v))) + return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v))) + + case linux.IP_RECVTOS: + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0)) + + case linux.IP_PKTINFO: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0)) + + case linux.IP_HDRINCL: + if len(optVal) == 0 { + return nil + } + v, err := parseIntOrChar(optVal) + if err != nil { + return err + } + return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0)) case linux.IP_ADD_SOURCE_MEMBERSHIP, linux.IP_BIND_ADDRESS_NO_PORT, @@ -1769,7 +2221,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_CHECKSUM, linux.IP_DROP_SOURCE_MEMBERSHIP, linux.IP_FREEBIND, - linux.IP_HDRINCL, linux.IP_IPSEC_POLICY, linux.IP_MINTTL, linux.IP_MSFILTER, @@ -1778,12 +2229,10 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s linux.IP_NODEFRAG, linux.IP_OPTIONS, linux.IP_PASSSEC, - linux.IP_PKTINFO, linux.IP_RECVERR, linux.IP_RECVFRAGSIZE, linux.IP_RECVOPTS, linux.IP_RECVORIGDSTADDR, - linux.IP_RECVTOS, linux.IP_RECVTTL, linux.IP_RETOPTS, linux.IP_TRANSPARENT, @@ -1811,30 +2260,20 @@ func emitUnimplementedEventTCP(t *kernel.Task, name int) { switch name { case linux.TCP_CONGESTION, linux.TCP_CORK, - linux.TCP_DEFER_ACCEPT, linux.TCP_FASTOPEN, linux.TCP_FASTOPEN_CONNECT, linux.TCP_FASTOPEN_KEY, linux.TCP_FASTOPEN_NO_COOKIE, - linux.TCP_KEEPCNT, - linux.TCP_KEEPIDLE, - linux.TCP_KEEPINTVL, - linux.TCP_LINGER2, - linux.TCP_MAXSEG, linux.TCP_QUEUE_SEQ, - linux.TCP_QUICKACK, linux.TCP_REPAIR, linux.TCP_REPAIR_QUEUE, linux.TCP_REPAIR_WINDOW, linux.TCP_SAVED_SYN, linux.TCP_SAVE_SYN, - linux.TCP_SYNCNT, linux.TCP_THIN_DUPACK, linux.TCP_THIN_LINEAR_TIMEOUTS, linux.TCP_TIMESTAMP, - linux.TCP_ULP, - linux.TCP_USER_TIMEOUT, - linux.TCP_WINDOW_CLAMP: + linux.TCP_ULP: t.Kernel().EmitUnimplementedEvent(t) } @@ -1876,7 +2315,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) { linux.IPV6_RECVPATHMTU, linux.IPV6_RECVPKTINFO, linux.IPV6_RECVRTHDR, - linux.IPV6_RECVTCLASS, linux.IPV6_RTHDR, linux.IPV6_RTHDRDSTOPTS, linux.IPV6_TCLASS, @@ -1981,8 +2419,8 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) case linux.AF_INET6: var out linux.SockAddrInet6 - if len(addr.Addr) == 4 { - // Copy address is v4-mapped format. + if len(addr.Addr) == header.IPv4AddressSize { + // Copy address in v4-mapped format. copy(out.Addr[12:], addr.Addr) out.Addr[10] = 0xff out.Addr[11] = 0xff @@ -1997,7 +2435,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) return &out, uint32(sockAddrInet6Size) case linux.AF_PACKET: - // TODO(b/129292371): Return protocol too. + // TODO(gvisor.dev/issue/173): Return protocol too. var out linux.SockAddrLink out.Family = linux.AF_PACKET out.InterfaceIndex = int32(addr.NIC) @@ -2012,7 +2450,7 @@ func ConvertAddress(family int, addr tcpip.FullAddress) (linux.SockAddr, uint32) // GetSockName implements the linux syscall getsockname(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2024,7 +2462,7 @@ func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.Endpoint.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -2039,16 +2477,21 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // caller. // // Precondition: s.readMu must be locked. -func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { +func (s *socketOpsCommon) coalescingRead(ctx context.Context, dst usermem.IOSequence, discard bool) (int, *syserr.Error) { var err *syserr.Error var copied int // Copy as many views as possible into the user-provided buffer. - for dst.NumBytes() != 0 { + for { + // Always do at least one fetchReadView, even if the number of bytes to + // read is 0. err = s.fetchReadView() if err != nil { break } + if dst.NumBytes() == 0 { + break + } var n int var e error @@ -2066,6 +2509,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq } copied += n s.readView.TrimFront(n) + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + dst = dst.DropFirst(n) if e != nil { err = syserr.FromError(e) @@ -2082,7 +2529,7 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq return 0, err } -func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { +func (s *socketOpsCommon) fillCmsgInq(cmsg *socket.ControlMessages) { if !s.sockOptInq { return } @@ -2094,10 +2541,27 @@ func (s *SocketOperations) fillCmsgInq(cmsg *socket.ControlMessages) { cmsg.IP.Inq = int32(len(s.readView) + rcvBufUsed) } +func toLinuxPacketType(pktType tcpip.PacketType) uint8 { + switch pktType { + case tcpip.PacketHost: + return linux.PACKET_HOST + case tcpip.PacketOtherHost: + return linux.PACKET_OTHERHOST + case tcpip.PacketOutgoing: + return linux.PACKET_OUTGOING + case tcpip.PacketBroadcast: + return linux.PACKET_BROADCAST + case tcpip.PacketMulticast: + return linux.PACKET_MULTICAST + default: + panic(fmt.Sprintf("unknown packet type: %d", pktType)) + } +} + // nonBlockingRead issues a non-blocking read. // // TODO(b/78348848): Support timestamps for stream sockets. -func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { +func (s *socketOpsCommon) nonBlockingRead(ctx context.Context, dst usermem.IOSequence, peek, trunc, senderRequested bool) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { isPacket := s.isPacketBased() // Fast path for regular reads from stream (e.g., TCP) endpoints. Note @@ -2112,9 +2576,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe // caller-supplied buffer. s.readMu.Lock() n, err := s.coalescingRead(ctx, dst, trunc) - s.readMu.Unlock() cmsg := s.controlMessages() s.fillCmsgInq(&cmsg) + s.readMu.Unlock() return n, 0, nil, 0, cmsg, err } @@ -2149,6 +2613,11 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe var addrLen uint32 if isPacket && senderRequested { addr, addrLen = ConvertAddress(s.family, s.sender) + switch v := addr.(type) { + case *linux.SockAddrLink: + v.Protocol = htons(uint16(s.linkPacketInfo.Protocol)) + v.PacketType = toLinuxPacketType(s.linkPacketInfo.PktType) + } } if peek { @@ -2188,6 +2657,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe s.readView.TrimFront(int(n)) } + if len(s.readView) == 0 { + atomic.StoreUint32(&s.readViewHasData, 0) + } + var flags int if msgLen > int(n) { flags |= linux.MSG_TRUNC @@ -2202,15 +2675,26 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe return n, flags, addr, addrLen, cmsg, syserr.FromError(err) } -func (s *SocketOperations) controlMessages() socket.ControlMessages { - return socket.ControlMessages{IP: tcpip.ControlMessages{HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, Timestamp: s.readCM.Timestamp}} +func (s *socketOpsCommon) controlMessages() socket.ControlMessages { + return socket.ControlMessages{ + IP: tcpip.ControlMessages{ + HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp, + Timestamp: s.readCM.Timestamp, + HasTOS: s.readCM.HasTOS, + TOS: s.readCM.TOS, + HasTClass: s.readCM.HasTClass, + TClass: s.readCM.TClass, + HasIPPacketInfo: s.readCM.HasIPPacketInfo, + PacketInfo: s.readCM.PacketInfo, + }, + } } // updateTimestamp sets the timestamp for SIOCGSTAMP. It should be called after // successfully writing packet data out to userspace. // // Precondition: s.readMu must be locked. -func (s *SocketOperations) updateTimestamp() { +func (s *socketOpsCommon) updateTimestamp() { // Save the SIOCGSTAMP timestamp only if SO_TIMESTAMP is disabled. if !s.sockOptTimestamp { s.timestampValid = true @@ -2220,7 +2704,7 @@ func (s *SocketOperations) updateTimestamp() { // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -2288,7 +2772,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags // SendMsg implements the linux syscall sendmsg(2) for sockets backed by // tcpip.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { // Reject Unix control messages. if !controlMessages.Unix.Empty() { return 0, syserr.ErrInvalidArgument @@ -2296,10 +2780,14 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] var addr *tcpip.FullAddress if len(to) > 0 { - addrBuf, _, err := AddressAndFamily(s.family, to, true /* strict */) + addrBuf, family, err := AddressAndFamily(to) if err != nil { return 0, err } + if err := s.checkFamily(family, false /* exact */); err != nil { + return 0, err + } + addrBuf = s.mapFamily(addrBuf, family) addr = &addrBuf } @@ -2360,11 +2848,20 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] // Ioctl implements fs.FileOperations.Ioctl. func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return s.socketOpsCommon.ioctl(ctx, io, args) +} + +func (s *socketOpsCommon) ioctl(ctx context.Context, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + // SIOCGSTAMP is implemented by netstack rather than all commonEndpoint // sockets. // TODO(b/78348848): Add a commonEndpoint method to support SIOCGSTAMP. switch args[1].Int() { - case syscall.SIOCGSTAMP: + case linux.SIOCGSTAMP: s.readMu.Lock() defer s.readMu.Unlock() if !s.timestampValid { @@ -2372,9 +2869,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } tv := linux.NsecToTimeval(s.timestampNS) - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &tv, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := tv.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2384,16 +2879,17 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, } // Add bytes removed from the endpoint but not yet sent to the caller. + s.readMu.Lock() v += len(s.readView) + s.readMu.Unlock() if v > math.MaxInt32 { v = math.MaxInt32 } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + // Copy result to userspace. + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err } @@ -2402,52 +2898,49 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // Ioctl performs a socket ioctl. func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { + t := kernel.TaskFromContext(ctx) + if t == nil { + panic("ioctl(2) may only be called from a task goroutine") + } + switch arg := int(args[1].Int()); arg { - case syscall.SIOCGIFFLAGS, - syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFFLAGS, + linux.SIOCGIFADDR, + linux.SIOCGIFBRDADDR, + linux.SIOCGIFDSTADDR, + linux.SIOCGIFHWADDR, + linux.SIOCGIFINDEX, + linux.SIOCGIFMAP, + linux.SIOCGIFMETRIC, + linux.SIOCGIFMTU, + linux.SIOCGIFNAME, + linux.SIOCGIFNETMASK, + linux.SIOCGIFTXQLEN, + linux.SIOCETHTOOL: var ifr linux.IFReq - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } if err := interfaceIoctl(ctx, io, arg, &ifr); err != nil { return 0, err.ToError() } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), &ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }) + _, err := ifr.CopyOut(t, args[2].Pointer()) return 0, err - case syscall.SIOCGIFCONF: + case linux.SIOCGIFCONF: // Return a list of interface addresses or the buffer size // necessary to hold the list. var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifc.CopyIn(t, args[2].Pointer()); err != nil { return 0, err } - if err := ifconfIoctl(ctx, io, &ifc); err != nil { + if err := ifconfIoctl(ctx, t, io, &ifc); err != nil { return 0, err } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - + _, err := ifc.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCINQ: @@ -2459,10 +2952,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc if v > math.MaxInt32 { v = math.MaxInt32 } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + // Copy result to userspace. + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.TIOCOUTQ: @@ -2475,10 +2967,9 @@ func Ioctl(ctx context.Context, ep commonEndpoint, io usermem.IO, args arch.Sysc v = math.MaxInt32 } - // Copy result to user-space. - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), int32(v), usermem.IOOpts{ - AddressSpaceActive: true, - }) + // Copy result to userspace. + vP := primitive.Int32(v) + _, err := vP.CopyOut(t, args[2].Pointer()) return 0, err case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: @@ -2504,7 +2995,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // SIOCGIFNAME uses ifr.ifr_ifindex rather than ifr.ifr_name to // identify a device. - if arg == syscall.SIOCGIFNAME { + if arg == linux.SIOCGIFNAME { // Gets the name of the interface given the interface index // stored in ifr_ifindex. index = int32(usermem.ByteOrder.Uint32(ifr.Data[:4])) @@ -2527,21 +3018,28 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } switch arg { - case syscall.SIOCGIFINDEX: + case linux.SIOCGIFINDEX: // Copy out the index to the data. usermem.ByteOrder.PutUint32(ifr.Data[:], uint32(index)) - case syscall.SIOCGIFHWADDR: + case linux.SIOCGIFHWADDR: // Copy the hardware address out. - ifr.Data[0] = 6 // IEEE802.2 arp type. - ifr.Data[1] = 0 + // + // Refer: https://linux.die.net/man/7/netdevice + // SIOCGIFHWADDR, SIOCSIFHWADDR + // + // Get or set the hardware address of a device using + // ifr_hwaddr. The hardware address is specified in a struct + // sockaddr. sa_family contains the ARPHRD_* device type, + // sa_data the L2 hardware address starting from byte 0. Setting + // the hardware address is a privileged operation. + usermem.ByteOrder.PutUint16(ifr.Data[:], iface.DeviceType) n := copy(ifr.Data[2:], iface.Addr) for i := 2 + n; i < len(ifr.Data); i++ { ifr.Data[i] = 0 // Clear padding. } - usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(n)) - case syscall.SIOCGIFFLAGS: + case linux.SIOCGIFFLAGS: f, err := interfaceStatusFlags(stack, iface.Name) if err != nil { return err @@ -2550,7 +3048,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe // matches Linux behavior. usermem.ByteOrder.PutUint16(ifr.Data[:2], uint16(f)) - case syscall.SIOCGIFADDR: + case linux.SIOCGIFADDR: // Copy the IPv4 address out. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2561,32 +3059,32 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } - case syscall.SIOCGIFMETRIC: + case linux.SIOCGIFMETRIC: // Gets the metric of the device. As per netdevice(7), this // always just sets ifr_metric to 0. usermem.ByteOrder.PutUint32(ifr.Data[:4], 0) - case syscall.SIOCGIFMTU: + case linux.SIOCGIFMTU: // Gets the MTU of the device. usermem.ByteOrder.PutUint32(ifr.Data[:4], iface.MTU) - case syscall.SIOCGIFMAP: + case linux.SIOCGIFMAP: // Gets the hardware parameters of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFTXQLEN: + case linux.SIOCGIFTXQLEN: // Gets the transmit queue length of the device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFDSTADDR: + case linux.SIOCGIFDSTADDR: // Gets the destination address of a point-to-point device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFBRDADDR: + case linux.SIOCGIFBRDADDR: // Gets the broadcast address of a device. // TODO(gvisor.dev/issue/505): Implement. - case syscall.SIOCGIFNETMASK: + case linux.SIOCGIFNETMASK: // Gets the network mask of a device. for _, addr := range stack.InterfaceAddrs()[index] { // This ioctl is only compatible with AF_INET addresses. @@ -2603,6 +3101,14 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe break } + case linux.SIOCETHTOOL: + // Stubbed out for now, Ideally we should implement the required + // sub-commands for ETHTOOL + // + // See: + // https://github.com/torvalds/linux/blob/aa0c9086b40c17a7ad94425b3b70dd1fdd7497bf/net/core/dev_ioctl.c + return syserr.ErrEndpointOperation + default: // Not a valid call. return syserr.ErrInvalidArgument @@ -2612,7 +3118,7 @@ func interfaceIoctl(ctx context.Context, io usermem.IO, arg int, ifr *linux.IFRe } // ifconfIoctl populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { +func ifconfIoctl(ctx context.Context, t *kernel.Task, io usermem.IO, ifc *linux.IFConf) error { // If Ptr is NULL, return the necessary buffer size via Len. // Otherwise, write up to Len bytes starting at Ptr containing ifreq // structs. @@ -2649,9 +3155,7 @@ func ifconfIoctl(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { // Copy the ifr to userspace. dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { + if _, err := ifr.CopyOut(t, usermem.Addr(dst)); err != nil { return err } } @@ -2697,7 +3201,7 @@ func nicStateFlagsToLinux(f stack.NICStateFlags) uint32 { // State implements socket.Socket.State. State translates the internal state // returned by netstack to values defined by Linux. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { if s.family != linux.AF_INET && s.family != linux.AF_INET6 { // States not implemented for this socket's family. return 0 @@ -2757,6 +3261,8 @@ func (s *SocketOperations) State() uint32 { } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { return s.family, s.skType, s.protocol } + +// LINT.ThenChange(./netstack_vfs2.go) diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go new file mode 100644 index 000000000..3335e7430 --- /dev/null +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -0,0 +1,332 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netstack + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/amutex" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/inet" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" + "gvisor.dev/gvisor/tools/go_marshal/primitive" +) + +// SocketVFS2 encapsulates all the state needed to represent a network stack +// endpoint in the kernel context. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD + + socketOpsCommon +} + +var _ = socket.SocketVFS2(&SocketVFS2{}) + +// NewVFS2 creates a new endpoint socket. +func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*vfs.FileDescription, *syserr.Error) { + if skType == linux.SOCK_STREAM { + if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil { + return nil, syserr.TranslateNetstackError(err) + } + } + + mnt := t.Kernel().SocketMount() + d := sockfs.NewDentry(t.Credentials(), mnt) + + s := &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + Queue: queue, + family: family, + Endpoint: endpoint, + skType: skType, + protocol: protocol, + }, + } + s.LockFD.Init(&vfs.FileLocks{}) + vfsfd := &s.vfsfd + if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, syserr.FromError(err) + } + return vfsfd, nil +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + n, _, _, _, _, err := s.nonBlockingRead(ctx, dst, false, false, false) + if err == syserr.ErrWouldBlock { + return int64(n), syserror.ErrWouldBlock + } + if err != nil { + return 0, err.ToError() + } + return int64(n), nil +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + f := &ioSequencePayload{ctx: ctx, src: src} + n, resCh, err := s.Endpoint.Write(f, tcpip.WriteOptions{}) + if err == tcpip.ErrWouldBlock { + return 0, syserror.ErrWouldBlock + } + + if resCh != nil { + if err := amutex.Block(ctx, resCh); err != nil { + return 0, err + } + n, _, err = s.Endpoint.Write(f, tcpip.WriteOptions{}) + } + + if err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + + if int64(n) < src.NumBytes() { + return int64(n), syserror.ErrWouldBlock + } + + return int64(n), nil +} + +// Accept implements the linux syscall accept(2) for sockets backed by +// tcpip.Endpoint. +func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + // Issue the accept request to get the new endpoint. + ep, wq, terr := s.Endpoint.Accept() + if terr != nil { + if terr != tcpip.ErrWouldBlock || !blocking { + return 0, nil, 0, syserr.TranslateNetstackError(terr) + } + + var err *syserr.Error + ep, wq, err = s.blockingAccept(t) + if err != nil { + return 0, nil, 0, err + } + } + + ns, err := NewVFS2(t, s.family, s.skType, s.protocol, wq, ep) + if err != nil { + return 0, nil, 0, err + } + defer ns.DecRef(t) + + if err := ns.SetStatusFlags(t, t.Credentials(), uint32(flags&linux.SOCK_NONBLOCK)); err != nil { + return 0, nil, 0, syserr.FromError(err) + } + + var addr linux.SockAddr + var addrLen uint32 + if peerRequested { + // Get address of the peer and write it to peer slice. + var err *syserr.Error + addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) + if err != nil { + return 0, nil, 0, err + } + } + + fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ + CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, + }) + + t.Kernel().RecordSocketVFS2(ns) + + return fd, addr, addrLen, syserr.FromError(e) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return s.socketOpsCommon.ioctl(ctx, uio, args) +} + +// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by +// tcpip.Endpoint. +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is + // implemented specifically for netstack.SocketVFS2 rather than + // commonEndpoint. commonEndpoint should be extended to support socket + // options where the implementation is not shared, as unix sockets need + // their own support for SO_TIMESTAMP. + if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := primitive.Int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptTimestamp { + val = 1 + } + return &val, nil + } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + val := primitive.Int32(0) + s.readMu.Lock() + defer s.readMu.Unlock() + if s.sockOptInq { + val = 1 + } + return &val, nil + } + + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) + if err != nil { + return nil, err + } + return &info, nil + + case linux.IPT_SO_GET_ENTRIES: + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen) + if err != nil { + return nil, err + } + return &entries, nil + + } + } + + return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen) +} + +// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by +// tcpip.Endpoint. +func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { + // TODO(b/78348848): Unlike other socket options, SO_TIMESTAMP is + // implemented specifically for netstack.SocketVFS2 rather than + // commonEndpoint. commonEndpoint should be extended to support socket + // options where the implementation is not shared, as unix sockets need + // their own support for SO_TIMESTAMP. + if level == linux.SOL_SOCKET && name == linux.SO_TIMESTAMP { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptTimestamp = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } + if level == linux.SOL_TCP && name == linux.TCP_INQ { + if len(optVal) < sizeOfInt32 { + return syserr.ErrInvalidArgument + } + s.readMu.Lock() + defer s.readMu.Unlock() + s.sockOptInq = usermem.ByteOrder.Uint32(optVal) != 0 + return nil + } + + if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP { + switch name { + case linux.IPT_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIPTReplace { + return syserr.ErrInvalidArgument + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal) + + case linux.IPT_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + } + } + + return SetSockOpt(t, s, s.Endpoint, level, name, optVal) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go index 2d2c1ba2a..ead3b2b79 100644 --- a/pkg/sentry/socket/netstack/provider.go +++ b/pkg/sentry/socket/netstack/provider.go @@ -18,7 +18,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -33,6 +33,8 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// LINT.IfChange + // provider is an inet socket provider. type provider struct { family int @@ -62,10 +64,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in } case linux.SOCK_RAW: - // TODO(b/142504697): "In order to create a raw socket, a - // process must have the CAP_NET_RAW capability in the user - // namespace that governs its network namespace." - raw(7) - // Raw sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(ctx) if !creds.HasCapability(linux.CAP_NET_RAW) { @@ -75,6 +73,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in switch protocol { case syscall.IPPROTO_ICMP: return header.ICMPv4ProtocolNumber, true, nil + case syscall.IPPROTO_ICMPV6: + return header.ICMPv6ProtocolNumber, true, nil case syscall.IPPROTO_UDP: return header.UDPProtocolNumber, true, nil case syscall.IPPROTO_TCP: @@ -124,6 +124,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) } else { ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) + + // Assign task to PacketOwner interface to get the UID and GID for + // iptables owner matching. + if e == nil { + ep.SetOwner(t) + } } if e != nil { return nil, syserr.TranslateNetstackError(e) @@ -133,10 +139,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (* } func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) { - // TODO(b/142504697): "In order to create a packet socket, a process - // must have the CAP_NET_RAW capability in the user namespace that - // governs its network namespace." - packet(7) - // Packet sockets require CAP_NET_RAW. creds := auth.CredentialsFromContext(t) if !creds.HasCapability(linux.CAP_NET_RAW) { @@ -167,6 +169,8 @@ func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol return New(t, linux.AF_PACKET, stype, protocol, wq, ep) } +// LINT.ThenChange(./provider_vfs2.go) + // Pair just returns nil sockets (not supported). func (*provider) Pair(*kernel.Task, linux.SockType, int) (*fs.File, *fs.File, *syserr.Error) { return nil, nil, nil diff --git a/pkg/sentry/socket/netstack/provider_vfs2.go b/pkg/sentry/socket/netstack/provider_vfs2.go new file mode 100644 index 000000000..2a01143f6 --- /dev/null +++ b/pkg/sentry/socket/netstack/provider_vfs2.go @@ -0,0 +1,141 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netstack + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/waiter" +) + +// providerVFS2 is an inet socket provider. +type providerVFS2 struct { + family int + netProto tcpip.NetworkProtocolNumber +} + +// Socket creates a new socket object for the AF_INET, AF_INET6, or AF_PACKET +// family. +func (p *providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Fail right away if we don't have a stack. + stack := t.NetworkContext() + if stack == nil { + // Don't propagate an error here. Instead, allow the socket + // code to continue searching for another provider. + return nil, nil + } + eps, ok := stack.(*Stack) + if !ok { + return nil, nil + } + + // Packet sockets are handled separately, since they are neither INET + // nor INET6 specific. + if p.family == linux.AF_PACKET { + return packetSocketVFS2(t, eps, stype, protocol) + } + + // Figure out the transport protocol. + transProto, associated, err := getTransportProtocol(t, stype, protocol) + if err != nil { + return nil, err + } + + // Create the endpoint. + var ep tcpip.Endpoint + var e *tcpip.Error + wq := &waiter.Queue{} + if stype == linux.SOCK_RAW { + ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated) + } else { + ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq) + + // Assign task to PacketOwner interface to get the UID and GID for + // iptables owner matching. + if e == nil { + ep.SetOwner(t) + } + } + if e != nil { + return nil, syserr.TranslateNetstackError(e) + } + + return NewVFS2(t, p.family, stype, int(transProto), wq, ep) +} + +func packetSocketVFS2(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Packet sockets require CAP_NET_RAW. + creds := auth.CredentialsFromContext(t) + if !creds.HasCapability(linux.CAP_NET_RAW) { + return nil, syserr.ErrNotPermitted + } + + // "cooked" packets don't contain link layer information. + var cooked bool + switch stype { + case linux.SOCK_DGRAM: + cooked = true + case linux.SOCK_RAW: + cooked = false + default: + return nil, syserr.ErrProtocolNotSupported + } + + // protocol is passed in network byte order, but netstack wants it in + // host order. + netProto := tcpip.NetworkProtocolNumber(ntohs(uint16(protocol))) + + wq := &waiter.Queue{} + ep, err := epStack.Stack.NewPacketEndpoint(cooked, netProto, wq) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + + return NewVFS2(t, linux.AF_PACKET, stype, protocol, wq, ep) +} + +// Pair just returns nil sockets (not supported). +func (*providerVFS2) Pair(*kernel.Task, linux.SockType, int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + return nil, nil, nil +} + +// init registers socket providers for AF_INET, AF_INET6, and AF_PACKET. +func init() { + // Providers backed by netstack. + p := []providerVFS2{ + { + family: linux.AF_INET, + netProto: ipv4.ProtocolNumber, + }, + { + family: linux.AF_INET6, + netProto: ipv6.ProtocolNumber, + }, + { + family: linux.AF_PACKET, + }, + } + + for i := range p { + socket.RegisterProviderVFS2(p[i].family, &p[i]) + } +} diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index d0102cfa3..f9097d6b2 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -15,14 +15,15 @@ package netstack import ( + "fmt" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/socket/netfilter" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" - "gvisor.dev/gvisor/pkg/tcpip/iptables" "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" "gvisor.dev/gvisor/pkg/tcpip/stack" @@ -41,19 +42,29 @@ func (s *Stack) SupportsIPv6() bool { return s.Stack.CheckNetworkProtocol(ipv6.ProtocolNumber) } +// Converts Netstack's ARPHardwareType to equivalent linux constants. +func toLinuxARPHardwareType(t header.ARPHardwareType) uint16 { + switch t { + case header.ARPHardwareNone: + return linux.ARPHRD_NONE + case header.ARPHardwareLoopback: + return linux.ARPHRD_LOOPBACK + case header.ARPHardwareEther: + return linux.ARPHRD_ETHER + default: + panic(fmt.Sprintf("unknown ARPHRD type: %d", t)) + } +} + // Interfaces implements inet.Stack.Interfaces. func (s *Stack) Interfaces() map[int32]inet.Interface { is := make(map[int32]inet.Interface) for id, ni := range s.Stack.NICInfo() { - var devType uint16 - if ni.Flags.Loopback { - devType = linux.ARPHRD_LOOPBACK - } is[int32(id)] = inet.Interface{ Name: ni.Name, Addr: []byte(ni.LinkAddress), Flags: uint32(nicStateFlagsToLinux(ni.Flags)), - DeviceType: devType, + DeviceType: toLinuxARPHardwareType(ni.ARPHardwareType), MTU: ni.MTU, } } @@ -89,6 +100,59 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + var ( + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + ) + switch addr.Family { + case linux.AF_INET: + if len(addr.Addr) < header.IPv4AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv4AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv4.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) + + case linux.AF_INET6: + if len(addr.Addr) < header.IPv6AddressSize { + return syserror.EINVAL + } + if addr.PrefixLen > header.IPv6AddressSize*8 { + return syserror.EINVAL + } + protocol = ipv6.ProtocolNumber + address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) + + default: + return syserror.ENOTSUP + } + + protocolAddress := tcpip.ProtocolAddress{ + Protocol: protocol, + AddressWithPrefix: tcpip.AddressWithPrefix{ + Address: address, + PrefixLen: int(addr.PrefixLen), + }, + } + + // Attach address to interface. + if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network. + s.Stack.AddRoute(tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: tcpip.NICID(idx), + }) + return nil +} + // TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { var rs tcp.ReceiveBufferSizeOption @@ -143,39 +207,83 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error { return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError() } +// TCPRecovery implements inet.Stack.TCPRecovery. +func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { + var recovery tcp.Recovery + if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil { + return 0, syserr.TranslateNetstackError(err).ToError() + } + return inet.TCPLossRecovery(recovery), nil +} + +// SetTCPRecovery implements inet.Stack.SetTCPRecovery. +func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { + return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError() +} + // Statistics implements inet.Stack.Statistics. func (s *Stack) Statistics(stat interface{}, arg string) error { switch stats := stat.(type) { + case *inet.StatDev: + for _, ni := range s.Stack.NICInfo() { + if ni.Name != arg { + continue + } + // TODO(gvisor.dev/issue/2103) Support stubbed stats. + *stats = inet.StatDev{ + // Receive section. + ni.Stats.Rx.Bytes.Value(), // bytes. + ni.Stats.Rx.Packets.Value(), // packets. + 0, // errs. + 0, // drop. + 0, // fifo. + 0, // frame. + 0, // compressed. + 0, // multicast. + // Transmit section. + ni.Stats.Tx.Bytes.Value(), // bytes. + ni.Stats.Tx.Packets.Value(), // packets. + 0, // errs. + 0, // drop. + 0, // fifo. + 0, // colls. + 0, // carrier. + 0, // compressed. + } + break + } case *inet.StatSNMPIP: ip := Metrics.IP + // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPIP{ - 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding. - 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL. - ip.PacketsReceived.Value(), // InReceives. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors. - ip.InvalidAddressesReceived.Value(), // InAddrErrors. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos. - 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards. - ip.PacketsDelivered.Value(), // InDelivers. - ip.PacketsSent.Value(), // OutRequests. - ip.OutgoingPacketErrors.Value(), // OutDiscards. - 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails. - 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates. + 0, // Ip/Forwarding. + 0, // Ip/DefaultTTL. + ip.PacketsReceived.Value(), // InReceives. + 0, // Ip/InHdrErrors. + ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors. + 0, // Ip/ForwDatagrams. + 0, // Ip/InUnknownProtos. + 0, // Ip/InDiscards. + ip.PacketsDelivered.Value(), // InDelivers. + ip.PacketsSent.Value(), // OutRequests. + ip.OutgoingPacketErrors.Value(), // OutDiscards. + 0, // Ip/OutNoRoutes. + 0, // Support Ip/ReasmTimeout. + 0, // Support Ip/ReasmReqds. + 0, // Support Ip/ReasmOKs. + 0, // Support Ip/ReasmFails. + 0, // Support Ip/FragOKs. + 0, // Support Ip/FragFails. + 0, // Support Ip/FragCreates. } case *inet.StatSNMPICMP: in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats + // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPICMP{ - 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs. + 0, // Icmp/InMsgs. Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors. - 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors. + 0, // Icmp/InCsumErrors. in.DstUnreachable.Value(), // InDestUnreachs. in.TimeExceeded.Value(), // InTimeExcds. in.ParamProblem.Value(), // InParmProbs. @@ -187,7 +295,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { in.TimestampReply.Value(), // InTimestampReps. in.InfoRequest.Value(), // InAddrMasks. in.InfoReply.Value(), // InAddrMaskReps. - 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs. + 0, // Icmp/OutMsgs. Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors. out.DstUnreachable.Value(), // OutDestUnreachs. out.TimeExceeded.Value(), // OutTimeExcds. @@ -223,15 +331,16 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { } case *inet.StatSNMPUDP: udp := Metrics.UDP + // TODO(gvisor.dev/issue/969) Support stubbed stats. *stats = inet.StatSNMPUDP{ udp.PacketsReceived.Value(), // InDatagrams. udp.UnknownPortErrors.Value(), // NoPorts. - 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors. + 0, // Udp/InErrors. udp.PacketsSent.Value(), // OutDatagrams. udp.ReceiveBufferErrors.Value(), // RcvbufErrors. - 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors. - 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors. - 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti. + 0, // Udp/SndbufErrors. + udp.ChecksumErrors.Value(), // Udp/InCsumErrors. + 0, // Udp/IgnoredMulti. } default: return syserr.ErrEndpointOperation.ToError() @@ -278,21 +387,30 @@ func (s *Stack) RouteTable() []inet.Route { } // IPTables returns the stack's iptables. -func (s *Stack) IPTables() (iptables.IPTables, error) { +func (s *Stack) IPTables() (*stack.IPTables, error) { return s.Stack.IPTables(), nil } -// FillDefaultIPTables sets the stack's iptables to the default tables, which -// allow and do not modify all traffic. -func (s *Stack) FillDefaultIPTables() { - netfilter.FillDefaultIPTables(s.Stack) -} - // Resume implements inet.Stack.Resume. func (s *Stack) Resume() { s.Stack.Resume() } +// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { + return s.Stack.RegisteredEndpoints() +} + +// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { + return s.Stack.CleanupEndpoints() +} + +// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +func (s *Stack) RestoreCleanupEndpoints(es []stack.TransportEndpoint) { + s.Stack.RestoreCleanupEndpoints(es) +} + // Forwarding implements inet.Stack.Forwarding. func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { switch protocol { diff --git a/pkg/sentry/socket/rpcinet/BUILD b/pkg/sentry/socket/rpcinet/BUILD deleted file mode 100644 index 3a6baa308..000000000 --- a/pkg/sentry/socket/rpcinet/BUILD +++ /dev/null @@ -1,68 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") -load("@io_bazel_rules_go//proto:def.bzl", "go_proto_library") -load("@rules_cc//cc:defs.bzl", "cc_proto_library") - -package(licenses = ["notice"]) - -go_library( - name = "rpcinet", - srcs = [ - "device.go", - "rpcinet.go", - "socket.go", - "stack.go", - "stack_unsafe.go", - ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet", - visibility = ["//pkg/sentry:internal"], - deps = [ - ":syscall_rpc_go_proto", - "//pkg/abi/linux", - "//pkg/binary", - "//pkg/sentry/arch", - "//pkg/sentry/context", - "//pkg/sentry/device", - "//pkg/sentry/fs", - "//pkg/sentry/fs/fsutil", - "//pkg/sentry/inet", - "//pkg/sentry/kernel", - "//pkg/sentry/kernel/time", - "//pkg/sentry/socket", - "//pkg/sentry/socket/hostinet", - "//pkg/sentry/socket/rpcinet/conn", - "//pkg/sentry/socket/rpcinet/notifier", - "//pkg/sentry/unimpl", - "//pkg/sentry/usermem", - "//pkg/syserr", - "//pkg/syserror", - "//pkg/tcpip", - "//pkg/tcpip/buffer", - "//pkg/unet", - "//pkg/waiter", - ], -) - -proto_library( - name = "syscall_rpc_proto", - srcs = ["syscall_rpc.proto"], - visibility = [ - "//visibility:public", - ], -) - -cc_proto_library( - name = "syscall_rpc_cc_proto", - visibility = [ - "//visibility:public", - ], - deps = [":syscall_rpc_proto"], -) - -go_proto_library( - name = "syscall_rpc_go_proto", - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto", - proto = ":syscall_rpc_proto", - visibility = [ - "//visibility:public", - ], -) diff --git a/pkg/sentry/socket/rpcinet/conn/BUILD b/pkg/sentry/socket/rpcinet/conn/BUILD deleted file mode 100644 index 23eadcb1b..000000000 --- a/pkg/sentry/socket/rpcinet/conn/BUILD +++ /dev/null @@ -1,17 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "conn", - srcs = ["conn.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn", - visibility = ["//pkg/sentry:internal"], - deps = [ - "//pkg/binary", - "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", - "//pkg/syserr", - "//pkg/unet", - "@com_github_golang_protobuf//proto:go_default_library", - ], -) diff --git a/pkg/sentry/socket/rpcinet/conn/conn.go b/pkg/sentry/socket/rpcinet/conn/conn.go deleted file mode 100644 index 356adad99..000000000 --- a/pkg/sentry/socket/rpcinet/conn/conn.go +++ /dev/null @@ -1,187 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package conn is an RPC connection to a syscall RPC server. -package conn - -import ( - "fmt" - "sync" - "sync/atomic" - "syscall" - - "github.com/golang/protobuf/proto" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/unet" - - pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" -) - -type request struct { - response []byte - ready chan struct{} - ignoreResult bool -} - -// RPCConnection represents a single RPC connection to a syscall gofer. -type RPCConnection struct { - // reqID is the ID of the last request and must be accessed atomically. - reqID uint64 - - sendMu sync.Mutex - socket *unet.Socket - - reqMu sync.Mutex - requests map[uint64]request -} - -// NewRPCConnection initializes a RPC connection to a socket gofer. -func NewRPCConnection(s *unet.Socket) *RPCConnection { - conn := &RPCConnection{socket: s, requests: map[uint64]request{}} - go func() { // S/R-FIXME(b/77962828) - var nums [16]byte - for { - for n := 0; n < len(nums); { - nn, err := conn.socket.Read(nums[n:]) - if err != nil { - panic(fmt.Sprint("error reading length from socket rpc gofer: ", err)) - } - n += nn - } - - b := make([]byte, binary.LittleEndian.Uint64(nums[:8])) - id := binary.LittleEndian.Uint64(nums[8:]) - - for n := 0; n < len(b); { - nn, err := conn.socket.Read(b[n:]) - if err != nil { - panic(fmt.Sprint("error reading request from socket rpc gofer: ", err)) - } - n += nn - } - - conn.reqMu.Lock() - r := conn.requests[id] - if r.ignoreResult { - delete(conn.requests, id) - } else { - r.response = b - conn.requests[id] = r - } - conn.reqMu.Unlock() - close(r.ready) - } - }() - return conn -} - -// NewRequest makes a request to the RPC gofer and returns the request ID and a -// channel which will be closed once the request completes. -func (c *RPCConnection) NewRequest(req pb.SyscallRequest, ignoreResult bool) (uint64, chan struct{}) { - b, err := proto.Marshal(&req) - if err != nil { - panic(fmt.Sprint("invalid proto: ", err)) - } - - id := atomic.AddUint64(&c.reqID, 1) - ch := make(chan struct{}) - - c.reqMu.Lock() - c.requests[id] = request{ready: ch, ignoreResult: ignoreResult} - c.reqMu.Unlock() - - c.sendMu.Lock() - defer c.sendMu.Unlock() - - var nums [16]byte - binary.LittleEndian.PutUint64(nums[:8], uint64(len(b))) - binary.LittleEndian.PutUint64(nums[8:], id) - for n := 0; n < len(nums); { - nn, err := c.socket.Write(nums[n:]) - if err != nil { - panic(fmt.Sprint("error writing length and ID to socket gofer: ", err)) - } - n += nn - } - - for n := 0; n < len(b); { - nn, err := c.socket.Write(b[n:]) - if err != nil { - panic(fmt.Sprint("error writing request to socket gofer: ", err)) - } - n += nn - } - - return id, ch -} - -// RPCReadFile will execute the ReadFile helper RPC method which avoids the -// common pattern of open(2), read(2), close(2) by doing all three operations -// as a single RPC. It will read the entire file or return EFBIG if the file -// was too large. -func (c *RPCConnection) RPCReadFile(path string) ([]byte, *syserr.Error) { - req := &pb.SyscallRequest_ReadFile{&pb.ReadFileRequest{ - Path: path, - }} - - id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-ch - - res := c.Request(id).Result.(*pb.SyscallResponse_ReadFile).ReadFile.Result - if e, ok := res.(*pb.ReadFileResponse_ErrorNumber); ok { - return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.ReadFileResponse_Data).Data, nil -} - -// RPCWriteFile will execute the WriteFile helper RPC method which avoids the -// common pattern of open(2), write(2), write(2), close(2) by doing all -// operations as a single RPC. -func (c *RPCConnection) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) { - req := &pb.SyscallRequest_WriteFile{&pb.WriteFileRequest{ - Path: path, - Content: data, - }} - - id, ch := c.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-ch - - res := c.Request(id).Result.(*pb.SyscallResponse_WriteFile).WriteFile - if e := res.ErrorNumber; e != 0 { - return int64(res.Written), syserr.FromHost(syscall.Errno(e)) - } - - return int64(res.Written), nil -} - -// Request retrieves the request corresponding to the given request ID. -// -// The channel returned by NewRequest must have been closed before Request can -// be called. This will happen automatically, do not manually close the -// channel. -func (c *RPCConnection) Request(id uint64) pb.SyscallResponse { - c.reqMu.Lock() - r := c.requests[id] - delete(c.requests, id) - c.reqMu.Unlock() - - var resp pb.SyscallResponse - if err := proto.Unmarshal(r.response, &resp); err != nil { - panic(fmt.Sprint("invalid proto: ", err)) - } - - return resp -} diff --git a/pkg/sentry/socket/rpcinet/notifier/BUILD b/pkg/sentry/socket/rpcinet/notifier/BUILD deleted file mode 100644 index a3585e10d..000000000 --- a/pkg/sentry/socket/rpcinet/notifier/BUILD +++ /dev/null @@ -1,16 +0,0 @@ -load("//tools/go_stateify:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "notifier", - srcs = ["notifier.go"], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier", - visibility = ["//:sandbox"], - deps = [ - "//pkg/sentry/socket/rpcinet:syscall_rpc_go_proto", - "//pkg/sentry/socket/rpcinet/conn", - "//pkg/waiter", - "@org_golang_x_sys//unix:go_default_library", - ], -) diff --git a/pkg/sentry/socket/rpcinet/notifier/notifier.go b/pkg/sentry/socket/rpcinet/notifier/notifier.go deleted file mode 100644 index 7efe4301f..000000000 --- a/pkg/sentry/socket/rpcinet/notifier/notifier.go +++ /dev/null @@ -1,231 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package notifier implements an FD notifier implementation over RPC. -package notifier - -import ( - "fmt" - "sync" - "syscall" - - "golang.org/x/sys/unix" - "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" - pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" - "gvisor.dev/gvisor/pkg/waiter" -) - -type fdInfo struct { - queue *waiter.Queue - waiting bool -} - -// Notifier holds all the state necessary to issue notifications when IO events -// occur in the observed FDs. -type Notifier struct { - // rpcConn is the connection that is used for sending RPCs. - rpcConn *conn.RPCConnection - - // epFD is the epoll file descriptor used to register for io - // notifications. - epFD uint32 - - // mu protects fdMap. - mu sync.Mutex - - // fdMap maps file descriptors to their notification queues and waiting - // status. - fdMap map[uint32]*fdInfo -} - -// NewRPCNotifier creates a new notifier object. -func NewRPCNotifier(cn *conn.RPCConnection) (*Notifier, error) { - id, c := cn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCreate1{&pb.EpollCreate1Request{}}}, false /* ignoreResult */) - <-c - - res := cn.Request(id).Result.(*pb.SyscallResponse_EpollCreate1).EpollCreate1.Result - if e, ok := res.(*pb.EpollCreate1Response_ErrorNumber); ok { - return nil, syscall.Errno(e.ErrorNumber) - } - - w := &Notifier{ - rpcConn: cn, - epFD: res.(*pb.EpollCreate1Response_Fd).Fd, - fdMap: make(map[uint32]*fdInfo), - } - - go w.waitAndNotify() // S/R-FIXME(b/77962828) - - return w, nil -} - -// waitFD waits on mask for fd. The fdMap mutex must be hold. -func (n *Notifier) waitFD(fd uint32, fi *fdInfo, mask waiter.EventMask) error { - if !fi.waiting && mask == 0 { - return nil - } - - e := pb.EpollEvent{ - Events: mask.ToLinux() | unix.EPOLLET, - Fd: fd, - } - - switch { - case !fi.waiting && mask != 0: - id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_ADD, Fd: fd, Event: &e}}}, false /* ignoreResult */) - <-c - - e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber - if e != 0 { - return syscall.Errno(e) - } - - fi.waiting = true - case fi.waiting && mask == 0: - id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_DEL, Fd: fd}}}, false /* ignoreResult */) - <-c - n.rpcConn.Request(id) - - fi.waiting = false - case fi.waiting && mask != 0: - id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollCtl{&pb.EpollCtlRequest{Epfd: n.epFD, Op: syscall.EPOLL_CTL_MOD, Fd: fd, Event: &e}}}, false /* ignoreResult */) - <-c - - e := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollCtl).EpollCtl.ErrorNumber - if e != 0 { - return syscall.Errno(e) - } - } - - return nil -} - -// addFD adds an FD to the list of FDs observed by n. -func (n *Notifier) addFD(fd uint32, queue *waiter.Queue) { - n.mu.Lock() - defer n.mu.Unlock() - - // Panic if we're already notifying on this FD. - if _, ok := n.fdMap[fd]; ok { - panic(fmt.Sprintf("File descriptor %d added twice", fd)) - } - - // We have nothing to wait for at the moment. Just add it to the map. - n.fdMap[fd] = &fdInfo{queue: queue} -} - -// updateFD updates the set of events the FD needs to be notified on. -func (n *Notifier) updateFD(fd uint32) error { - n.mu.Lock() - defer n.mu.Unlock() - - if fi, ok := n.fdMap[fd]; ok { - return n.waitFD(fd, fi, fi.queue.Events()) - } - - return nil -} - -// RemoveFD removes an FD from the list of FDs observed by n. -func (n *Notifier) removeFD(fd uint32) { - n.mu.Lock() - defer n.mu.Unlock() - - // Remove from map, then from epoll object. - n.waitFD(fd, n.fdMap[fd], 0) - delete(n.fdMap, fd) -} - -// hasFD returns true if the FD is in the list of observed FDs. -func (n *Notifier) hasFD(fd uint32) bool { - n.mu.Lock() - defer n.mu.Unlock() - - _, ok := n.fdMap[fd] - return ok -} - -// waitAndNotify loops waiting for io event notifications from the epoll -// object. Once notifications arrive, they are dispatched to the -// registered queue. -func (n *Notifier) waitAndNotify() error { - for { - id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_EpollWait{&pb.EpollWaitRequest{Fd: n.epFD, NumEvents: 100, Msec: -1}}}, false /* ignoreResult */) - <-c - - res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_EpollWait).EpollWait.Result - if e, ok := res.(*pb.EpollWaitResponse_ErrorNumber); ok { - err := syscall.Errno(e.ErrorNumber) - // NOTE(magi): I don't think epoll_wait can return EAGAIN but I'm being - // conseratively careful here since exiting the notification thread - // would be really bad. - if err == syscall.EINTR || err == syscall.EAGAIN { - continue - } - return err - } - - n.mu.Lock() - for _, e := range res.(*pb.EpollWaitResponse_Events).Events.Events { - if fi, ok := n.fdMap[e.Fd]; ok { - fi.queue.Notify(waiter.EventMaskFromLinux(e.Events)) - } - } - n.mu.Unlock() - } -} - -// AddFD adds an FD to the list of observed FDs. -func (n *Notifier) AddFD(fd uint32, queue *waiter.Queue) error { - n.addFD(fd, queue) - return nil -} - -// UpdateFD updates the set of events the FD needs to be notified on. -func (n *Notifier) UpdateFD(fd uint32) error { - return n.updateFD(fd) -} - -// RemoveFD removes an FD from the list of observed FDs. -func (n *Notifier) RemoveFD(fd uint32) { - n.removeFD(fd) -} - -// HasFD returns true if the FD is in the list of observed FDs. -// -// This should only be used by tests to assert that FDs are correctly -// registered. -func (n *Notifier) HasFD(fd uint32) bool { - return n.hasFD(fd) -} - -// NonBlockingPoll polls the given fd in non-blocking fashion. It is used just -// to query the FD's current state; this method will block on the RPC response -// although the syscall is non-blocking. -func (n *Notifier) NonBlockingPoll(fd uint32, mask waiter.EventMask) waiter.EventMask { - for { - id, c := n.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Poll{&pb.PollRequest{Fd: fd, Events: mask.ToLinux()}}}, false /* ignoreResult */) - <-c - - res := n.rpcConn.Request(id).Result.(*pb.SyscallResponse_Poll).Poll.Result - if e, ok := res.(*pb.PollResponse_ErrorNumber); ok { - if syscall.Errno(e.ErrorNumber) == syscall.EINTR { - continue - } - return mask - } - - return waiter.EventMaskFromLinux(res.(*pb.PollResponse_Events).Events) - } -} diff --git a/pkg/sentry/socket/rpcinet/rpcinet.go b/pkg/sentry/socket/rpcinet/rpcinet.go deleted file mode 100644 index 5d4fd4dac..000000000 --- a/pkg/sentry/socket/rpcinet/rpcinet.go +++ /dev/null @@ -1,16 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package rpcinet implements sockets using an RPC for each syscall. -package rpcinet diff --git a/pkg/sentry/socket/rpcinet/socket.go b/pkg/sentry/socket/rpcinet/socket.go deleted file mode 100644 index ddb76d9d4..000000000 --- a/pkg/sentry/socket/rpcinet/socket.go +++ /dev/null @@ -1,909 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rpcinet - -import ( - "sync/atomic" - "syscall" - "time" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/fs" - "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" - "gvisor.dev/gvisor/pkg/sentry/kernel" - ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" - "gvisor.dev/gvisor/pkg/sentry/socket" - "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" - "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" - pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" - "gvisor.dev/gvisor/pkg/sentry/unimpl" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/tcpip/buffer" - "gvisor.dev/gvisor/pkg/waiter" -) - -// socketOperations implements fs.FileOperations and socket.Socket for a socket -// implemented using a host socket. -type socketOperations struct { - fsutil.FilePipeSeek `state:"nosave"` - fsutil.FileNotDirReaddir `state:"nosave"` - fsutil.FileNoFsync `state:"nosave"` - fsutil.FileNoMMap `state:"nosave"` - fsutil.FileNoSplice `state:"nosave"` - fsutil.FileNoopFlush `state:"nosave"` - fsutil.FileUseInodeUnstableAttr `state:"nosave"` - socket.SendReceiveTimeout - - family int // Read-only. - stype linux.SockType // Read-only. - protocol int // Read-only. - - fd uint32 // must be O_NONBLOCK - wq *waiter.Queue - rpcConn *conn.RPCConnection - notifier *notifier.Notifier - - // shState is the state of the connection with respect to shutdown. Because - // we're mixing non-blocking semantics on the other side we have to adapt for - // some strange differences between blocking and non-blocking sockets. - shState int32 -} - -// Verify that we actually implement socket.Socket. -var _ = socket.Socket(&socketOperations{}) - -// New creates a new RPC socket. -func newSocketFile(ctx context.Context, stack *Stack, family int, skType linux.SockType, protocol int) (*fs.File, *syserr.Error) { - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(family), Type: int64(skType | syscall.SOCK_NONBLOCK), Protocol: int64(protocol)}}}, false /* ignoreResult */) - <-c - - res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result - if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok { - return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - fd := res.(*pb.SocketResponse_Fd).Fd - - var wq waiter.Queue - stack.notifier.AddFD(fd, &wq) - - dirent := socket.NewDirent(ctx, socketDevice) - defer dirent.DecRef() - return fs.NewFile(ctx, dirent, fs.FileFlags{Read: true, Write: true}, &socketOperations{ - family: family, - stype: skType, - protocol: protocol, - wq: &wq, - fd: fd, - rpcConn: stack.rpcConn, - notifier: stack.notifier, - }), nil -} - -func isBlockingErrno(err error) bool { - return err == syscall.EAGAIN || err == syscall.EWOULDBLOCK -} - -func translateIOSyscallError(err error) error { - if isBlockingErrno(err) { - return syserror.ErrWouldBlock - } - return err -} - -// setShutdownFlags will set the shutdown flag so we can handle blocking reads -// after a read shutdown. -func (s *socketOperations) setShutdownFlags(how int) { - var f tcpip.ShutdownFlags - switch how { - case linux.SHUT_RD: - f = tcpip.ShutdownRead - case linux.SHUT_WR: - f = tcpip.ShutdownWrite - case linux.SHUT_RDWR: - f = tcpip.ShutdownWrite | tcpip.ShutdownRead - } - - // Atomically update the flags. - for { - old := atomic.LoadInt32(&s.shState) - if atomic.CompareAndSwapInt32(&s.shState, old, old|int32(f)) { - break - } - } -} - -func (s *socketOperations) resetShutdownFlags() { - atomic.StoreInt32(&s.shState, 0) -} - -func (s *socketOperations) isShutRdSet() bool { - return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownRead) != 0 -} - -func (s *socketOperations) isShutWrSet() bool { - return atomic.LoadInt32(&s.shState)&int32(tcpip.ShutdownWrite) != 0 -} - -// Release implements fs.FileOperations.Release. -func (s *socketOperations) Release() { - s.notifier.RemoveFD(s.fd) - - // We always need to close the FD. - _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: s.fd}}}, true /* ignoreResult */) -} - -// Readiness implements waiter.Waitable.Readiness. -func (s *socketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { - return s.notifier.NonBlockingPoll(s.fd, mask) -} - -// EventRegister implements waiter.Waitable.EventRegister. -func (s *socketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { - s.wq.EventRegister(e, mask) - s.notifier.UpdateFD(s.fd) -} - -// EventUnregister implements waiter.Waitable.EventUnregister. -func (s *socketOperations) EventUnregister(e *waiter.Entry) { - s.wq.EventUnregister(e) - s.notifier.UpdateFD(s.fd) -} - -func rpcRead(t *kernel.Task, req *pb.SyscallRequest_Read) (*pb.ReadResponse_Data, *syserr.Error) { - s := t.NetworkContext().(*Stack) - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-c - - res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Read).Read.Result - if e, ok := res.(*pb.ReadResponse_ErrorNumber); ok { - return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.ReadResponse_Data), nil -} - -// Read implements fs.FileOperations.Read. -func (s *socketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, _ int64) (int64, error) { - req := &pb.SyscallRequest_Read{&pb.ReadRequest{ - Fd: s.fd, - Length: uint32(dst.NumBytes()), - }} - - res, se := rpcRead(ctx.(*kernel.Task), req) - if se == nil { - n, e := dst.CopyOut(ctx, res.Data) - return int64(n), e - } - - return 0, se.ToError() -} - -func rpcWrite(t *kernel.Task, req *pb.SyscallRequest_Write) (uint32, *syserr.Error) { - s := t.NetworkContext().(*Stack) - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-c - - res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Write).Write.Result - if e, ok := res.(*pb.WriteResponse_ErrorNumber); ok { - return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.WriteResponse_Length).Length, nil -} - -// Write implements fs.FileOperations.Write. -func (s *socketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { - t := ctx.(*kernel.Task) - v := buffer.NewView(int(src.NumBytes())) - - // Copy all the data into the buffer. - if _, err := src.CopyIn(t, v); err != nil { - return 0, err - } - - n, err := rpcWrite(t, &pb.SyscallRequest_Write{&pb.WriteRequest{Fd: s.fd, Data: v}}) - if n > 0 && n < uint32(src.NumBytes()) { - // The FileOperations.Write interface expects us to return ErrWouldBlock in - // the event of a partial write. - return int64(n), syserror.ErrWouldBlock - } - return int64(n), err.ToError() -} - -func rpcConnect(t *kernel.Task, fd uint32, sockaddr []byte) *syserr.Error { - s := t.NetworkContext().(*Stack) - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Connect{&pb.ConnectRequest{Fd: uint32(fd), Address: sockaddr}}}, false /* ignoreResult */) - <-c - - if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Connect).Connect.ErrorNumber; e != 0 { - return syserr.FromHost(syscall.Errno(e)) - } - return nil -} - -// Connect implements socket.Socket.Connect. -func (s *socketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { - if !blocking { - e := rpcConnect(t, s.fd, sockaddr) - if e == nil { - // Reset the shutdown state on new connects. - s.resetShutdownFlags() - } - return e - } - - // Register for notification when the endpoint becomes writable, then - // initiate the connection. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventOut|waiter.EventIn|waiter.EventHUp) - defer s.EventUnregister(&e) - for { - if err := rpcConnect(t, s.fd, sockaddr); err == nil || err != syserr.ErrInProgress && err != syserr.ErrAlreadyInProgress { - if err == nil { - // Reset the shutdown state on new connects. - s.resetShutdownFlags() - } - return err - } - - // It's pending, so we have to wait for a notification, and fetch the - // result once the wait completes. - if err := t.Block(ch); err != nil { - return syserr.FromError(err) - } - } -} - -func rpcAccept(t *kernel.Task, fd uint32, peer bool) (*pb.AcceptResponse_ResultPayload, *syserr.Error) { - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Accept{&pb.AcceptRequest{Fd: fd, Peer: peer, Flags: syscall.SOCK_NONBLOCK}}}, false /* ignoreResult */) - <-c - - res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Accept).Accept.Result - if e, ok := res.(*pb.AcceptResponse_ErrorNumber); ok { - return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - return res.(*pb.AcceptResponse_Payload).Payload, nil -} - -// Accept implements socket.Socket.Accept. -func (s *socketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { - payload, se := rpcAccept(t, s.fd, peerRequested) - - // Check if we need to block. - if blocking && se == syserr.ErrTryAgain { - // Register for notifications. - e, ch := waiter.NewChannelEntry(nil) - // FIXME(b/119878986): This waiter.EventHUp is a partial - // measure, need to figure out how to translate linux events to - // internal events. - s.EventRegister(&e, waiter.EventIn|waiter.EventHUp) - defer s.EventUnregister(&e) - - // Try to accept the connection again; if it fails, then wait until we - // get a notification. - for { - if payload, se = rpcAccept(t, s.fd, peerRequested); se != syserr.ErrTryAgain { - break - } - - if err := t.Block(ch); err != nil { - return 0, nil, 0, syserr.FromError(err) - } - } - } - - // Handle any error from accept. - if se != nil { - return 0, nil, 0, se - } - - var wq waiter.Queue - s.notifier.AddFD(payload.Fd, &wq) - - dirent := socket.NewDirent(t, socketDevice) - defer dirent.DecRef() - fileFlags := fs.FileFlags{ - Read: true, - Write: true, - NonSeekable: true, - NonBlocking: flags&linux.SOCK_NONBLOCK != 0, - } - file := fs.NewFile(t, dirent, fileFlags, &socketOperations{ - family: s.family, - stype: s.stype, - protocol: s.protocol, - wq: &wq, - fd: payload.Fd, - rpcConn: s.rpcConn, - notifier: s.notifier, - }) - defer file.DecRef() - - fd, err := t.NewFDFrom(0, file, kernel.FDFlags{ - CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, - }) - if err != nil { - return 0, nil, 0, syserr.FromError(err) - } - t.Kernel().RecordSocket(file) - - if peerRequested { - return fd, socket.UnmarshalSockAddr(s.family, payload.Address.Address), payload.Address.Length, nil - } - - return fd, nil, 0, nil -} - -// Bind implements socket.Socket.Bind. -func (s *socketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: s.fd, Address: sockaddr}}}, false /* ignoreResult */) - <-c - - if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 { - return syserr.FromHost(syscall.Errno(e)) - } - return nil -} - -// Listen implements socket.Socket.Listen. -func (s *socketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Listen{&pb.ListenRequest{Fd: s.fd, Backlog: int64(backlog)}}}, false /* ignoreResult */) - <-c - - if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Listen).Listen.ErrorNumber; e != 0 { - return syserr.FromHost(syscall.Errno(e)) - } - return nil -} - -// Shutdown implements socket.Socket.Shutdown. -func (s *socketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { - // We save the shutdown state because of strange differences on linux - // related to recvs on blocking vs. non-blocking sockets after a SHUT_RD. - // We need to emulate that behavior on the blocking side. - // TODO(b/120096741): There is a possible race that can exist with loopback, - // where data could possibly be lost. - s.setShutdownFlags(how) - - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Shutdown{&pb.ShutdownRequest{Fd: s.fd, How: int64(how)}}}, false /* ignoreResult */) - <-c - - if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Shutdown).Shutdown.ErrorNumber; e != 0 { - return syserr.FromHost(syscall.Errno(e)) - } - - return nil -} - -// GetSockOpt implements socket.Socket.GetSockOpt. -func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { - // SO_RCVTIMEO and SO_SNDTIMEO are special because blocking is performed - // within the sentry. - if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO { - if outLen < linux.SizeOfTimeval { - return nil, syserr.ErrInvalidArgument - } - - return linux.NsecToTimeval(s.RecvTimeout()), nil - } - if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO { - if outLen < linux.SizeOfTimeval { - return nil, syserr.ErrInvalidArgument - } - - return linux.NsecToTimeval(s.SendTimeout()), nil - } - - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockOpt{&pb.GetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Length: uint32(outLen)}}}, false /* ignoreResult */) - <-c - - res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockOpt).GetSockOpt.Result - if e, ok := res.(*pb.GetSockOptResponse_ErrorNumber); ok { - return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.GetSockOptResponse_Opt).Opt, nil -} - -// SetSockOpt implements socket.Socket.SetSockOpt. -func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error { - // Because blocking actually happens within the sentry we need to inspect - // this socket option to determine if it's a SO_RCVTIMEO or SO_SNDTIMEO, - // and if so, we will save it and use it as the deadline for recv(2) - // or send(2) related syscalls. - if level == linux.SOL_SOCKET && name == linux.SO_RCVTIMEO { - if len(opt) < linux.SizeOfTimeval { - return syserr.ErrInvalidArgument - } - - var v linux.Timeval - binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v) - if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { - return syserr.ErrDomain - } - s.SetRecvTimeout(v.ToNsecCapped()) - return nil - } - if level == linux.SOL_SOCKET && name == linux.SO_SNDTIMEO { - if len(opt) < linux.SizeOfTimeval { - return syserr.ErrInvalidArgument - } - - var v linux.Timeval - binary.Unmarshal(opt[:linux.SizeOfTimeval], usermem.ByteOrder, &v) - if v.Usec < 0 || v.Usec >= int64(time.Second/time.Microsecond) { - return syserr.ErrDomain - } - s.SetSendTimeout(v.ToNsecCapped()) - return nil - } - - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_SetSockOpt{&pb.SetSockOptRequest{Fd: s.fd, Level: int64(level), Name: int64(name), Opt: opt}}}, false /* ignoreResult */) - <-c - - if e := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_SetSockOpt).SetSockOpt.ErrorNumber; e != 0 { - return syserr.FromHost(syscall.Errno(e)) - } - return nil -} - -// GetPeerName implements socket.Socket.GetPeerName. -func (s *socketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetPeerName{&pb.GetPeerNameRequest{Fd: s.fd}}}, false /* ignoreResult */) - <-c - - res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetPeerName).GetPeerName.Result - if e, ok := res.(*pb.GetPeerNameResponse_ErrorNumber); ok { - return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - addr := res.(*pb.GetPeerNameResponse_Address).Address - return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil -} - -// GetSockName implements socket.Socket.GetSockName. -func (s *socketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { - stack := t.NetworkContext().(*Stack) - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_GetSockName{&pb.GetSockNameRequest{Fd: s.fd}}}, false /* ignoreResult */) - <-c - - res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_GetSockName).GetSockName.Result - if e, ok := res.(*pb.GetSockNameResponse_ErrorNumber); ok { - return nil, 0, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - addr := res.(*pb.GetSockNameResponse_Address).Address - return socket.UnmarshalSockAddr(s.family, addr.Address), addr.Length, nil -} - -func rpcIoctl(t *kernel.Task, fd, cmd uint32, arg []byte) ([]byte, error) { - stack := t.NetworkContext().(*Stack) - - id, c := stack.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Ioctl{&pb.IOCtlRequest{Fd: fd, Cmd: cmd, Arg: arg}}}, false /* ignoreResult */) - <-c - - res := stack.rpcConn.Request(id).Result.(*pb.SyscallResponse_Ioctl).Ioctl.Result - if e, ok := res.(*pb.IOCtlResponse_ErrorNumber); ok { - return nil, syscall.Errno(e.ErrorNumber) - } - - return res.(*pb.IOCtlResponse_Value).Value, nil -} - -// ifconfIoctlFromStack populates a struct ifconf for the SIOCGIFCONF ioctl. -func ifconfIoctlFromStack(ctx context.Context, io usermem.IO, ifc *linux.IFConf) error { - // If Ptr is NULL, return the necessary buffer size via Len. - // Otherwise, write up to Len bytes starting at Ptr containing ifreq - // structs. - t := ctx.(*kernel.Task) - s := t.NetworkContext().(*Stack) - if s == nil { - return syserr.ErrNoDevice.ToError() - } - - if ifc.Ptr == 0 { - ifc.Len = int32(len(s.Interfaces())) * int32(linux.SizeOfIFReq) - return nil - } - - max := ifc.Len - ifc.Len = 0 - for key, ifaceAddrs := range s.InterfaceAddrs() { - iface := s.Interfaces()[key] - for _, ifaceAddr := range ifaceAddrs { - // Don't write past the end of the buffer. - if ifc.Len+int32(linux.SizeOfIFReq) > max { - break - } - if ifaceAddr.Family != linux.AF_INET { - continue - } - - // Populate ifr.ifr_addr. - ifr := linux.IFReq{} - ifr.SetName(iface.Name) - usermem.ByteOrder.PutUint16(ifr.Data[0:2], uint16(ifaceAddr.Family)) - usermem.ByteOrder.PutUint16(ifr.Data[2:4], 0) - copy(ifr.Data[4:8], ifaceAddr.Addr[:4]) - - // Copy the ifr to userspace. - dst := uintptr(ifc.Ptr) + uintptr(ifc.Len) - ifc.Len += int32(linux.SizeOfIFReq) - if _, err := usermem.CopyObjectOut(ctx, io, usermem.Addr(dst), ifr, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { - return err - } - } - } - return nil -} - -// Ioctl implements fs.FileOperations.Ioctl. -func (s *socketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, args arch.SyscallArguments) (uintptr, error) { - t := ctx.(*kernel.Task) - - cmd := uint32(args[1].Int()) - arg := args[2].Pointer() - - var buf []byte - switch cmd { - // The following ioctls take 4 byte argument parameters. - case syscall.TIOCINQ, - syscall.TIOCOUTQ: - buf = make([]byte, 4) - // The following ioctls have args which are sizeof(struct ifreq). - case syscall.SIOCGIFADDR, - syscall.SIOCGIFBRDADDR, - syscall.SIOCGIFDSTADDR, - syscall.SIOCGIFFLAGS, - syscall.SIOCGIFHWADDR, - syscall.SIOCGIFINDEX, - syscall.SIOCGIFMAP, - syscall.SIOCGIFMETRIC, - syscall.SIOCGIFMTU, - syscall.SIOCGIFNAME, - syscall.SIOCGIFNETMASK, - syscall.SIOCGIFTXQLEN: - buf = make([]byte, linux.SizeOfIFReq) - case syscall.SIOCGIFCONF: - // SIOCGIFCONF has slightly different behavior than the others, in that it - // will need to populate the array of ifreqs. - var ifc linux.IFConf - if _, err := usermem.CopyObjectIn(ctx, io, args[2].Pointer(), &ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }); err != nil { - return 0, err - } - - if err := ifconfIoctlFromStack(ctx, io, &ifc); err != nil { - return 0, err - } - _, err := usermem.CopyObjectOut(ctx, io, args[2].Pointer(), ifc, usermem.IOOpts{ - AddressSpaceActive: true, - }) - - return 0, err - - case linux.SIOCGIFMEM, linux.SIOCGIFPFLAGS, linux.SIOCGMIIPHY, linux.SIOCGMIIREG: - unimpl.EmitUnimplementedEvent(ctx) - - default: - return 0, syserror.ENOTTY - } - - _, err := io.CopyIn(ctx, arg, buf, usermem.IOOpts{ - AddressSpaceActive: true, - }) - - if err != nil { - return 0, err - } - - v, err := rpcIoctl(t, s.fd, cmd, buf) - if err != nil { - return 0, err - } - - if len(v) != len(buf) { - return 0, syserror.EINVAL - } - - _, err = io.CopyOut(ctx, arg, v, usermem.IOOpts{ - AddressSpaceActive: true, - }) - return 0, err -} - -func rpcRecvMsg(t *kernel.Task, req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) { - s := t.NetworkContext().(*Stack) - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-c - - res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result - if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok { - return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.RecvmsgResponse_Payload).Payload, nil -} - -// Because we only support SO_TIMESTAMP we will search control messages for -// that value and set it if so, all other control messages will be ignored. -func (s *socketOperations) extractControlMessages(payload *pb.RecvmsgResponse_ResultPayload) socket.ControlMessages { - c := socket.ControlMessages{} - if len(payload.GetCmsgData()) > 0 { - // Parse the control messages looking for SO_TIMESTAMP. - msgs, e := syscall.ParseSocketControlMessage(payload.GetCmsgData()) - if e != nil { - return socket.ControlMessages{} - } - for _, m := range msgs { - if m.Header.Level != linux.SOL_SOCKET || m.Header.Type != linux.SO_TIMESTAMP { - continue - } - - // Let's parse the time stamp and set it. - if len(m.Data) < linux.SizeOfTimeval { - // Give up on locating the SO_TIMESTAMP option. - return socket.ControlMessages{} - } - - var v linux.Timeval - binary.Unmarshal(m.Data[:linux.SizeOfTimeval], usermem.ByteOrder, &v) - c.IP.HasTimestamp = true - c.IP.Timestamp = v.ToNsecCapped() - break - } - } - return c -} - -// RecvMsg implements socket.Socket.RecvMsg. -func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (int, int, linux.SockAddr, uint32, socket.ControlMessages, *syserr.Error) { - req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ - Fd: s.fd, - Length: uint32(dst.NumBytes()), - Sender: senderRequested, - Trunc: flags&linux.MSG_TRUNC != 0, - Peek: flags&linux.MSG_PEEK != 0, - CmsgLength: uint32(controlDataLen), - }} - - res, err := rpcRecvMsg(t, req) - if err == nil { - var e error - var n int - if len(res.Data) > 0 { - n, e = dst.CopyOut(t, res.Data) - if e == nil && n != len(res.Data) { - panic("CopyOut failed to copy full buffer") - } - } - c := s.extractControlMessages(res) - return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) - } - if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { - return 0, 0, nil, 0, socket.ControlMessages{}, err - } - - // We'll have to block. Register for notifications and keep trying to - // send all the data. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventIn) - defer s.EventUnregister(&e) - - for { - res, err := rpcRecvMsg(t, req) - if err == nil { - var e error - var n int - if len(res.Data) > 0 { - n, e = dst.CopyOut(t, res.Data) - if e == nil && n != len(res.Data) { - panic("CopyOut failed to copy full buffer") - } - } - c := s.extractControlMessages(res) - return int(res.Length), 0, socket.UnmarshalSockAddr(s.family, res.Address.GetAddress()), res.Address.GetLength(), c, syserr.FromError(e) - } - if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { - return 0, 0, nil, 0, socket.ControlMessages{}, err - } - - if s.isShutRdSet() { - // Blocking would have caused us to block indefinitely so we return 0, - // this is the same behavior as Linux. - return 0, 0, nil, 0, socket.ControlMessages{}, nil - } - - if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { - return 0, 0, nil, 0, socket.ControlMessages{}, syserr.ErrTryAgain - } - return 0, 0, nil, 0, socket.ControlMessages{}, syserr.FromError(err) - } - } -} - -func rpcSendMsg(t *kernel.Task, req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) { - s := t.NetworkContext().(*Stack) - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-c - - res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result - if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok { - return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.SendmsgResponse_Length).Length, nil -} - -// SendMsg implements socket.Socket.SendMsg. -func (s *socketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { - // Whitelist flags. - if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { - return 0, syserr.ErrInvalidArgument - } - - // Reject Unix control messages. - if !controlMessages.Unix.Empty() { - return 0, syserr.ErrInvalidArgument - } - - v := buffer.NewView(int(src.NumBytes())) - - // Copy all the data into the buffer. - if _, err := src.CopyIn(t, v); err != nil { - return 0, syserr.FromError(err) - } - - // TODO(bgeffon): this needs to change to map directly to a SendMsg syscall - // in the RPC. - totalWritten := 0 - n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ - Fd: uint32(s.fd), - Data: v, - Address: to, - More: flags&linux.MSG_MORE != 0, - EndOfRecord: flags&linux.MSG_EOR != 0, - }}) - - if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain || flags&linux.MSG_DONTWAIT != 0 { - return int(n), err - } - - if n > 0 { - totalWritten += int(n) - v.TrimFront(int(n)) - } - - // We'll have to block. Register for notification and keep trying to - // send all the data. - e, ch := waiter.NewChannelEntry(nil) - s.EventRegister(&e, waiter.EventOut) - defer s.EventUnregister(&e) - - for { - n, err := rpcSendMsg(t, &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ - Fd: uint32(s.fd), - Data: v, - Address: to, - More: flags&linux.MSG_MORE != 0, - EndOfRecord: flags&linux.MSG_EOR != 0, - }}) - - if n > 0 { - totalWritten += int(n) - v.TrimFront(int(n)) - - if err == nil && totalWritten < int(src.NumBytes()) { - continue - } - } - - if err != syserr.ErrWouldBlock && err != syserr.ErrTryAgain { - // We eat the error in this situation. - return int(totalWritten), nil - } - - if err := t.BlockWithDeadline(ch, haveDeadline, deadline); err != nil { - if err == syserror.ETIMEDOUT { - return int(totalWritten), syserr.ErrTryAgain - } - return int(totalWritten), syserr.FromError(err) - } - } -} - -// State implements socket.Socket.State. -func (s *socketOperations) State() uint32 { - // TODO(b/127845868): Define a new rpc to query the socket state. - return 0 -} - -// Type implements socket.Socket.Type. -func (s *socketOperations) Type() (family int, skType linux.SockType, protocol int) { - return s.family, s.stype, s.protocol -} - -type socketProvider struct { - family int -} - -// Socket implements socket.Provider.Socket. -func (p *socketProvider) Socket(t *kernel.Task, stypeflags linux.SockType, protocol int) (*fs.File, *syserr.Error) { - // Check that we are using the RPC network stack. - stack := t.NetworkContext() - if stack == nil { - return nil, nil - } - - s, ok := stack.(*Stack) - if !ok { - return nil, nil - } - - // Only accept TCP and UDP. - // - // Try to restrict the flags we will accept to minimize backwards - // incompatibility with netstack. - stype := stypeflags & linux.SOCK_TYPE_MASK - switch stype { - case syscall.SOCK_STREAM: - switch protocol { - case 0, syscall.IPPROTO_TCP: - // ok - default: - return nil, nil - } - case syscall.SOCK_DGRAM: - switch protocol { - case 0, syscall.IPPROTO_UDP: - // ok - default: - return nil, nil - } - default: - return nil, nil - } - - return newSocketFile(t, s, p.family, stype, protocol) -} - -// Pair implements socket.Provider.Pair. -func (p *socketProvider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.File, *fs.File, *syserr.Error) { - // Not supported by AF_INET/AF_INET6. - return nil, nil, nil -} - -func init() { - for _, family := range []int{syscall.AF_INET, syscall.AF_INET6} { - socket.RegisterProvider(family, &socketProvider{family}) - } -} diff --git a/pkg/sentry/socket/rpcinet/stack.go b/pkg/sentry/socket/rpcinet/stack.go deleted file mode 100644 index f5441b826..000000000 --- a/pkg/sentry/socket/rpcinet/stack.go +++ /dev/null @@ -1,178 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rpcinet - -import ( - "fmt" - "syscall" - - "gvisor.dev/gvisor/pkg/sentry/inet" - "gvisor.dev/gvisor/pkg/sentry/socket/hostinet" - "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/conn" - "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/notifier" - "gvisor.dev/gvisor/pkg/syserr" - "gvisor.dev/gvisor/pkg/tcpip" - "gvisor.dev/gvisor/pkg/unet" -) - -// Stack implements inet.Stack for RPC backed sockets. -type Stack struct { - interfaces map[int32]inet.Interface - interfaceAddrs map[int32][]inet.InterfaceAddr - routes []inet.Route - rpcConn *conn.RPCConnection - notifier *notifier.Notifier -} - -// NewStack returns a Stack containing the current state of the host network -// stack. -func NewStack(fd int32) (*Stack, error) { - sock, err := unet.NewSocket(int(fd)) - if err != nil { - return nil, err - } - - stack := &Stack{ - interfaces: make(map[int32]inet.Interface), - interfaceAddrs: make(map[int32][]inet.InterfaceAddr), - rpcConn: conn.NewRPCConnection(sock), - } - - var e error - stack.notifier, e = notifier.NewRPCNotifier(stack.rpcConn) - if e != nil { - return nil, e - } - - links, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETLINK) - if err != nil { - return nil, fmt.Errorf("RTM_GETLINK failed: %v", err) - } - - addrs, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETADDR) - if err != nil { - return nil, fmt.Errorf("RTM_GETADDR failed: %v", err) - } - - e = hostinet.ExtractHostInterfaces(links, addrs, stack.interfaces, stack.interfaceAddrs) - if e != nil { - return nil, e - } - - routes, err := stack.DoNetlinkRouteRequest(syscall.RTM_GETROUTE) - if err != nil { - return nil, fmt.Errorf("RTM_GETROUTE failed: %v", err) - } - - stack.routes, e = hostinet.ExtractHostRoutes(routes) - if e != nil { - return nil, e - } - - return stack, nil -} - -// RPCReadFile will execute the ReadFile helper RPC method which avoids the -// common pattern of open(2), read(2), close(2) by doing all three operations -// as a single RPC. It will read the entire file or return EFBIG if the file -// was too large. -func (s *Stack) RPCReadFile(path string) ([]byte, *syserr.Error) { - return s.rpcConn.RPCReadFile(path) -} - -// RPCWriteFile will execute the WriteFile helper RPC method which avoids the -// common pattern of open(2), write(2), write(2), close(2) by doing all -// operations as a single RPC. -func (s *Stack) RPCWriteFile(path string, data []byte) (int64, *syserr.Error) { - return s.rpcConn.RPCWriteFile(path, data) -} - -// Interfaces implements inet.Stack.Interfaces. -func (s *Stack) Interfaces() map[int32]inet.Interface { - interfaces := make(map[int32]inet.Interface) - for k, v := range s.interfaces { - interfaces[k] = v - } - return interfaces -} - -// InterfaceAddrs implements inet.Stack.InterfaceAddrs. -func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { - addrs := make(map[int32][]inet.InterfaceAddr) - for k, v := range s.interfaceAddrs { - addrs[k] = append([]inet.InterfaceAddr(nil), v...) - } - return addrs -} - -// SupportsIPv6 implements inet.Stack.SupportsIPv6. -func (s *Stack) SupportsIPv6() bool { - panic("rpcinet handles procfs directly this method should not be called") -} - -// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize. -func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { - panic("rpcinet handles procfs directly this method should not be called") -} - -// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. -func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { - panic("rpcinet handles procfs directly this method should not be called") - -} - -// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize. -func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { - panic("rpcinet handles procfs directly this method should not be called") - -} - -// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. -func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { - panic("rpcinet handles procfs directly this method should not be called") -} - -// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled. -func (s *Stack) TCPSACKEnabled() (bool, error) { - panic("rpcinet handles procfs directly this method should not be called") -} - -// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(enabled bool) error { - panic("rpcinet handles procfs directly this method should not be called") -} - -// Statistics implements inet.Stack.Statistics. -func (s *Stack) Statistics(stat interface{}, arg string) error { - return syserr.ErrEndpointOperation.ToError() -} - -// RouteTable implements inet.Stack.RouteTable. -func (s *Stack) RouteTable() []inet.Route { - return append([]inet.Route(nil), s.routes...) -} - -// Resume implements inet.Stack.Resume. -func (s *Stack) Resume() {} - -// Forwarding implements inet.Stack.Forwarding. -func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { - panic("rpcinet handles procfs directly this method should not be called") -} - -// SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { - panic("rpcinet handles procfs directly this method should not be called") -} diff --git a/pkg/sentry/socket/rpcinet/stack_unsafe.go b/pkg/sentry/socket/rpcinet/stack_unsafe.go deleted file mode 100644 index a94bdad83..000000000 --- a/pkg/sentry/socket/rpcinet/stack_unsafe.go +++ /dev/null @@ -1,193 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package rpcinet - -import ( - "syscall" - "unsafe" - - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/binary" - pb "gvisor.dev/gvisor/pkg/sentry/socket/rpcinet/syscall_rpc_go_proto" - "gvisor.dev/gvisor/pkg/sentry/usermem" - "gvisor.dev/gvisor/pkg/syserr" -) - -// NewNetlinkRouteRequest builds a netlink message for getting the RIB, -// the routing information base. -func newNetlinkRouteRequest(proto, seq, family int) []byte { - rr := &syscall.NetlinkRouteRequest{} - rr.Header.Len = uint32(syscall.NLMSG_HDRLEN + syscall.SizeofRtGenmsg) - rr.Header.Type = uint16(proto) - rr.Header.Flags = syscall.NLM_F_DUMP | syscall.NLM_F_REQUEST - rr.Header.Seq = uint32(seq) - rr.Data.Family = uint8(family) - return netlinkRRtoWireFormat(rr) -} - -func netlinkRRtoWireFormat(rr *syscall.NetlinkRouteRequest) []byte { - b := make([]byte, rr.Header.Len) - *(*uint32)(unsafe.Pointer(&b[0:4][0])) = rr.Header.Len - *(*uint16)(unsafe.Pointer(&b[4:6][0])) = rr.Header.Type - *(*uint16)(unsafe.Pointer(&b[6:8][0])) = rr.Header.Flags - *(*uint32)(unsafe.Pointer(&b[8:12][0])) = rr.Header.Seq - *(*uint32)(unsafe.Pointer(&b[12:16][0])) = rr.Header.Pid - b[16] = byte(rr.Data.Family) - return b -} - -func (s *Stack) getNetlinkFd() (uint32, *syserr.Error) { - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Socket{&pb.SocketRequest{Family: int64(syscall.AF_NETLINK), Type: int64(syscall.SOCK_RAW | syscall.SOCK_NONBLOCK), Protocol: int64(syscall.NETLINK_ROUTE)}}}, false /* ignoreResult */) - <-c - - res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Socket).Socket.Result - if e, ok := res.(*pb.SocketResponse_ErrorNumber); ok { - return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - return res.(*pb.SocketResponse_Fd).Fd, nil -} - -func (s *Stack) bindNetlinkFd(fd uint32, sockaddr []byte) *syserr.Error { - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Bind{&pb.BindRequest{Fd: fd, Address: sockaddr}}}, false /* ignoreResult */) - <-c - - if e := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Bind).Bind.ErrorNumber; e != 0 { - return syserr.FromHost(syscall.Errno(e)) - } - return nil -} - -func (s *Stack) closeNetlinkFd(fd uint32) { - _, _ = s.rpcConn.NewRequest(pb.SyscallRequest{Args: &pb.SyscallRequest_Close{&pb.CloseRequest{Fd: fd}}}, true /* ignoreResult */) -} - -func (s *Stack) rpcSendMsg(req *pb.SyscallRequest_Sendmsg) (uint32, *syserr.Error) { - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-c - - res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Sendmsg).Sendmsg.Result - if e, ok := res.(*pb.SendmsgResponse_ErrorNumber); ok { - return 0, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.SendmsgResponse_Length).Length, nil -} - -func (s *Stack) sendMsg(fd uint32, buf []byte, to []byte, flags int) (int, *syserr.Error) { - // Whitelist flags. - if flags&^(syscall.MSG_DONTWAIT|syscall.MSG_EOR|syscall.MSG_FASTOPEN|syscall.MSG_MORE|syscall.MSG_NOSIGNAL) != 0 { - return 0, syserr.ErrInvalidArgument - } - - req := &pb.SyscallRequest_Sendmsg{&pb.SendmsgRequest{ - Fd: fd, - Data: buf, - Address: to, - More: flags&linux.MSG_MORE != 0, - EndOfRecord: flags&linux.MSG_EOR != 0, - }} - - n, err := s.rpcSendMsg(req) - return int(n), err -} - -func (s *Stack) rpcRecvMsg(req *pb.SyscallRequest_Recvmsg) (*pb.RecvmsgResponse_ResultPayload, *syserr.Error) { - id, c := s.rpcConn.NewRequest(pb.SyscallRequest{Args: req}, false /* ignoreResult */) - <-c - - res := s.rpcConn.Request(id).Result.(*pb.SyscallResponse_Recvmsg).Recvmsg.Result - if e, ok := res.(*pb.RecvmsgResponse_ErrorNumber); ok { - return nil, syserr.FromHost(syscall.Errno(e.ErrorNumber)) - } - - return res.(*pb.RecvmsgResponse_Payload).Payload, nil -} - -func (s *Stack) recvMsg(fd, l, flags uint32) ([]byte, *syserr.Error) { - req := &pb.SyscallRequest_Recvmsg{&pb.RecvmsgRequest{ - Fd: fd, - Length: l, - Sender: false, - Trunc: flags&linux.MSG_TRUNC != 0, - Peek: flags&linux.MSG_PEEK != 0, - }} - - res, err := s.rpcRecvMsg(req) - if err != nil { - return nil, err - } - return res.Data, nil -} - -func (s *Stack) netlinkRequest(proto, family int) ([]byte, error) { - fd, err := s.getNetlinkFd() - if err != nil { - return nil, err.ToError() - } - defer s.closeNetlinkFd(fd) - - lsa := syscall.SockaddrNetlink{Family: syscall.AF_NETLINK} - b := binary.Marshal(nil, usermem.ByteOrder, &lsa) - if err := s.bindNetlinkFd(fd, b); err != nil { - return nil, err.ToError() - } - - wb := newNetlinkRouteRequest(proto, 1, family) - _, err = s.sendMsg(fd, wb, b, 0) - if err != nil { - return nil, err.ToError() - } - - var tab []byte -done: - for { - rb, err := s.recvMsg(fd, uint32(syscall.Getpagesize()), 0) - nr := len(rb) - if err != nil { - return nil, err.ToError() - } - - if nr < syscall.NLMSG_HDRLEN { - return nil, syserr.ErrInvalidArgument.ToError() - } - - tab = append(tab, rb...) - msgs, e := syscall.ParseNetlinkMessage(rb) - if e != nil { - return nil, e - } - - for _, m := range msgs { - if m.Header.Type == syscall.NLMSG_DONE { - break done - } - if m.Header.Type == syscall.NLMSG_ERROR { - return nil, syserr.ErrInvalidArgument.ToError() - } - } - } - - return tab, nil -} - -// DoNetlinkRouteRequest returns routing information base, also known as RIB, -// which consists of network facility information, states and parameters. -func (s *Stack) DoNetlinkRouteRequest(req int) ([]syscall.NetlinkMessage, error) { - data, err := s.netlinkRequest(req, syscall.AF_UNSPEC) - if err != nil { - return nil, err - } - return syscall.ParseNetlinkMessage(data) -} diff --git a/pkg/sentry/socket/rpcinet/syscall_rpc.proto b/pkg/sentry/socket/rpcinet/syscall_rpc.proto deleted file mode 100644 index 9586f5923..000000000 --- a/pkg/sentry/socket/rpcinet/syscall_rpc.proto +++ /dev/null @@ -1,353 +0,0 @@ -syntax = "proto3"; - -// package syscall_rpc is a set of networking related system calls that can be -// forwarded to a socket gofer. -// -// TODO(b/77963526): Document individual RPCs. -package syscall_rpc; - -message SendmsgRequest { - uint32 fd = 1; - bytes data = 2 [ctype = CORD]; - bytes address = 3; - bool more = 4; - bool end_of_record = 5; -} - -message SendmsgResponse { - oneof result { - uint32 error_number = 1; - uint32 length = 2; - } -} - -message IOCtlRequest { - uint32 fd = 1; - uint32 cmd = 2; - bytes arg = 3; -} - -message IOCtlResponse { - oneof result { - uint32 error_number = 1; - bytes value = 2; - } -} - -message RecvmsgRequest { - uint32 fd = 1; - uint32 length = 2; - bool sender = 3; - bool peek = 4; - bool trunc = 5; - uint32 cmsg_length = 6; -} - -message OpenRequest { - bytes path = 1; - uint32 flags = 2; - uint32 mode = 3; -} - -message OpenResponse { - oneof result { - uint32 error_number = 1; - uint32 fd = 2; - } -} - -message ReadRequest { - uint32 fd = 1; - uint32 length = 2; -} - -message ReadResponse { - oneof result { - uint32 error_number = 1; - bytes data = 2 [ctype = CORD]; - } -} - -message ReadFileRequest { - string path = 1; -} - -message ReadFileResponse { - oneof result { - uint32 error_number = 1; - bytes data = 2 [ctype = CORD]; - } -} - -message WriteRequest { - uint32 fd = 1; - bytes data = 2 [ctype = CORD]; -} - -message WriteResponse { - oneof result { - uint32 error_number = 1; - uint32 length = 2; - } -} - -message WriteFileRequest { - string path = 1; - bytes content = 2; -} - -message WriteFileResponse { - uint32 error_number = 1; - uint32 written = 2; -} - -message AddressResponse { - bytes address = 1; - uint32 length = 2; -} - -message RecvmsgResponse { - message ResultPayload { - bytes data = 1 [ctype = CORD]; - AddressResponse address = 2; - uint32 length = 3; - bytes cmsg_data = 4; - } - oneof result { - uint32 error_number = 1; - ResultPayload payload = 2; - } -} - -message BindRequest { - uint32 fd = 1; - bytes address = 2; -} - -message BindResponse { - uint32 error_number = 1; -} - -message AcceptRequest { - uint32 fd = 1; - bool peer = 2; - int64 flags = 3; -} - -message AcceptResponse { - message ResultPayload { - uint32 fd = 1; - AddressResponse address = 2; - } - oneof result { - uint32 error_number = 1; - ResultPayload payload = 2; - } -} - -message ConnectRequest { - uint32 fd = 1; - bytes address = 2; -} - -message ConnectResponse { - uint32 error_number = 1; -} - -message ListenRequest { - uint32 fd = 1; - int64 backlog = 2; -} - -message ListenResponse { - uint32 error_number = 1; -} - -message ShutdownRequest { - uint32 fd = 1; - int64 how = 2; -} - -message ShutdownResponse { - uint32 error_number = 1; -} - -message CloseRequest { - uint32 fd = 1; -} - -message CloseResponse { - uint32 error_number = 1; -} - -message GetSockOptRequest { - uint32 fd = 1; - int64 level = 2; - int64 name = 3; - uint32 length = 4; -} - -message GetSockOptResponse { - oneof result { - uint32 error_number = 1; - bytes opt = 2; - } -} - -message SetSockOptRequest { - uint32 fd = 1; - int64 level = 2; - int64 name = 3; - bytes opt = 4; -} - -message SetSockOptResponse { - uint32 error_number = 1; -} - -message GetSockNameRequest { - uint32 fd = 1; -} - -message GetSockNameResponse { - oneof result { - uint32 error_number = 1; - AddressResponse address = 2; - } -} - -message GetPeerNameRequest { - uint32 fd = 1; -} - -message GetPeerNameResponse { - oneof result { - uint32 error_number = 1; - AddressResponse address = 2; - } -} - -message SocketRequest { - int64 family = 1; - int64 type = 2; - int64 protocol = 3; -} - -message SocketResponse { - oneof result { - uint32 error_number = 1; - uint32 fd = 2; - } -} - -message EpollWaitRequest { - uint32 fd = 1; - uint32 num_events = 2; - sint64 msec = 3; -} - -message EpollEvent { - uint32 fd = 1; - uint32 events = 2; -} - -message EpollEvents { - repeated EpollEvent events = 1; -} - -message EpollWaitResponse { - oneof result { - uint32 error_number = 1; - EpollEvents events = 2; - } -} - -message EpollCtlRequest { - uint32 epfd = 1; - int64 op = 2; - uint32 fd = 3; - EpollEvent event = 4; -} - -message EpollCtlResponse { - uint32 error_number = 1; -} - -message EpollCreate1Request { - int64 flag = 1; -} - -message EpollCreate1Response { - oneof result { - uint32 error_number = 1; - uint32 fd = 2; - } -} - -message PollRequest { - uint32 fd = 1; - uint32 events = 2; -} - -message PollResponse { - oneof result { - uint32 error_number = 1; - uint32 events = 2; - } -} - -message SyscallRequest { - oneof args { - SocketRequest socket = 1; - SendmsgRequest sendmsg = 2; - RecvmsgRequest recvmsg = 3; - BindRequest bind = 4; - AcceptRequest accept = 5; - ConnectRequest connect = 6; - ListenRequest listen = 7; - ShutdownRequest shutdown = 8; - CloseRequest close = 9; - GetSockOptRequest get_sock_opt = 10; - SetSockOptRequest set_sock_opt = 11; - GetSockNameRequest get_sock_name = 12; - GetPeerNameRequest get_peer_name = 13; - EpollWaitRequest epoll_wait = 14; - EpollCtlRequest epoll_ctl = 15; - EpollCreate1Request epoll_create1 = 16; - PollRequest poll = 17; - ReadRequest read = 18; - WriteRequest write = 19; - OpenRequest open = 20; - IOCtlRequest ioctl = 21; - WriteFileRequest write_file = 22; - ReadFileRequest read_file = 23; - } -} - -message SyscallResponse { - oneof result { - SocketResponse socket = 1; - SendmsgResponse sendmsg = 2; - RecvmsgResponse recvmsg = 3; - BindResponse bind = 4; - AcceptResponse accept = 5; - ConnectResponse connect = 6; - ListenResponse listen = 7; - ShutdownResponse shutdown = 8; - CloseResponse close = 9; - GetSockOptResponse get_sock_opt = 10; - SetSockOptResponse set_sock_opt = 11; - GetSockNameResponse get_sock_name = 12; - GetPeerNameResponse get_peer_name = 13; - EpollWaitResponse epoll_wait = 14; - EpollCtlResponse epoll_ctl = 15; - EpollCreate1Response epoll_create1 = 16; - PollResponse poll = 17; - ReadResponse read = 18; - WriteResponse write = 19; - OpenResponse open = 20; - IOCtlResponse ioctl = 21; - WriteFileResponse write_file = 22; - ReadFileResponse read_file = 23; - } -} diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go index 8c250c325..04b259d27 100644 --- a/pkg/sentry/socket/socket.go +++ b/pkg/sentry/socket/socket.go @@ -24,16 +24,18 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/binary" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/device" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // ControlMessages represents the union of unix control messages and tcpip @@ -43,11 +45,30 @@ type ControlMessages struct { IP tcpip.ControlMessages } -// Socket is the interface containing socket syscalls used by the syscall layer -// to redirect them to the appropriate implementation. +// Release releases Unix domain socket credentials and rights. +func (c *ControlMessages) Release(ctx context.Context) { + c.Unix.Release(ctx) +} + +// Socket is an interface combining fs.FileOperations and SocketOps, +// representing a VFS1 socket file. type Socket interface { fs.FileOperations + SocketOps +} +// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps, +// representing a VFS2 socket file. +type SocketVFS2 interface { + vfs.FileDescriptionImpl + SocketOps +} + +// SocketOps is the interface containing socket syscalls used by the syscall +// layer to redirect them to the appropriate implementation. +// +// It is implemented by both Socket and SocketVFS2. +type SocketOps interface { // Connect implements the connect(2) linux syscall. Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error @@ -66,7 +87,7 @@ type Socket interface { Shutdown(t *kernel.Task, how int) *syserr.Error // GetSockOpt implements the getsockopt(2) linux syscall. - GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) + GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) // SetSockOpt implements the setsockopt(2) linux syscall. SetSockOpt(t *kernel.Task, level int, name int, opt []byte) *syserr.Error @@ -148,6 +169,8 @@ var families = make(map[int][]Provider) // RegisterProvider registers the provider of a given address family so that // sockets of that type can be created via socket() and/or socketpair() // syscalls. +// +// This should only be called during the initialization of the address family. func RegisterProvider(family int, provider Provider) { families[family] = append(families[family], provider) } @@ -211,6 +234,74 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent { return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino)) } +// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for +// specific address families (e.g., AF_INET). +type ProviderVFS2 interface { + // Socket creates a new socket. + // + // If a nil Socket _and_ a nil error is returned, it means that the + // protocol is not supported. A non-nil error should only be returned + // if the protocol is supported, but an error occurs during creation. + Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) + + // Pair creates a pair of connected sockets. + // + // See Socket for error information. + Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) +} + +// familiesVFS2 holds a map of all known address families and their providers. +var familiesVFS2 = make(map[int][]ProviderVFS2) + +// RegisterProviderVFS2 registers the provider of a given address family so that +// sockets of that type can be created via socket() and/or socketpair() +// syscalls. +// +// This should only be called during the initialization of the address family. +func RegisterProviderVFS2(family int, provider ProviderVFS2) { + familiesVFS2[family] = append(familiesVFS2[family], provider) +} + +// NewVFS2 creates a new socket with the given family, type and protocol. +func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + for _, p := range familiesVFS2[family] { + s, err := p.Socket(t, stype, protocol) + if err != nil { + return nil, err + } + if s != nil { + t.Kernel().RecordSocketVFS2(s) + return s, nil + } + } + + return nil, syserr.ErrAddressFamilyNotSupported +} + +// PairVFS2 creates a new connected socket pair with the given family, type and +// protocol. +func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + providers, ok := familiesVFS2[family] + if !ok { + return nil, nil, syserr.ErrAddressFamilyNotSupported + } + + for _, p := range providers { + s1, s2, err := p.Pair(t, stype, protocol) + if err != nil { + return nil, nil, err + } + if s1 != nil && s2 != nil { + k := t.Kernel() + k.RecordSocketVFS2(s1) + k.RecordSocketVFS2(s2) + return s1, s2, nil + } + } + + return nil, nil, syserr.ErrSocketNotSupported +} + // SendReceiveTimeout stores timeouts for send and receive calls. // // It is meant to be embedded into Socket implementations to help satisfy the @@ -317,7 +408,6 @@ func emitUnimplementedEvent(t *kernel.Task, name int) { linux.SO_MARK, linux.SO_MAX_PACING_RATE, linux.SO_NOFCS, - linux.SO_NO_CHECK, linux.SO_OOBINLINE, linux.SO_PASSCRED, linux.SO_PASSSEC, diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index 5b6a154f6..cb953e4dc 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -1,35 +1,54 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") +load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) +go_template_instance( + name = "socket_refs", + out = "socket_refs.go", + package = "unix", + prefix = "socketOpsCommon", + template = "//pkg/refs_vfs2:refs_template", + types = { + "T": "socketOpsCommon", + }, +) + go_library( name = "unix", srcs = [ "device.go", "io.go", + "socket_refs.go", "unix.go", + "unix_vfs2.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix", visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", + "//pkg/fspath", + "//pkg/log", "//pkg/refs", + "//pkg/safemem", "//pkg/sentry/arch", - "//pkg/sentry/context", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/fs/fsutil", + "//pkg/sentry/fs/lock", + "//pkg/sentry/fsimpl/sockfs", "//pkg/sentry/kernel", "//pkg/sentry/kernel/time", - "//pkg/sentry/safemem", "//pkg/sentry/socket", "//pkg/sentry/socket/control", "//pkg/sentry/socket/netstack", "//pkg/sentry/socket/unix/transport", - "//pkg/sentry/usermem", + "//pkg/sentry/vfs", "//pkg/syserr", "//pkg/syserror", "//pkg/tcpip", + "//pkg/usermem", "//pkg/waiter", + "//tools/go_marshal/marshal", ], ) diff --git a/pkg/sentry/socket/unix/io.go b/pkg/sentry/socket/unix/io.go index 2ec1a662d..129949990 100644 --- a/pkg/sentry/socket/unix/io.go +++ b/pkg/sentry/socket/unix/io.go @@ -15,8 +15,8 @@ package unix import ( - "gvisor.dev/gvisor/pkg/sentry/context" - "gvisor.dev/gvisor/pkg/sentry/safemem" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/tcpip" ) @@ -83,6 +83,19 @@ type EndpointReader struct { ControlTrunc bool } +// Truncate calls RecvMsg on the endpoint without writing to a destination. +func (r *EndpointReader) Truncate() error { + // Ignore bytes read since it will always be zero. + _, ms, c, ct, err := r.Endpoint.RecvMsg(r.Ctx, [][]byte{}, r.Creds, r.NumRights, r.Peek, r.From) + r.Control = c + r.ControlTrunc = ct + r.MsgSize = ms + if err != nil { + return err.ToError() + } + return nil +} + // ReadToBlocks implements safemem.Reader.ReadToBlocks. func (r *EndpointReader) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { return safemem.FromVecReaderFunc{func(bufs [][]byte) (int64, error) { diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 788ad70d2..c708b6030 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -1,4 +1,4 @@ -load("//tools/go_stateify:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library") load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) @@ -25,13 +25,14 @@ go_library( "transport_message_list.go", "unix.go", ], - importpath = "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport", visibility = ["//:sandbox"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/ilist", + "//pkg/log", "//pkg/refs", - "//pkg/sentry/context", + "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", "//pkg/tcpip/buffer", diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go index dea11e253..c67b602f0 100644 --- a/pkg/sentry/socket/unix/transport/connectioned.go +++ b/pkg/sentry/socket/unix/transport/connectioned.go @@ -15,10 +15,9 @@ package transport import ( - "sync" - "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" @@ -212,7 +211,7 @@ func (e *connectionedEndpoint) Listening() bool { // The socket will be a fresh state after a call to close and may be reused. // That is, close may be used to "unbind" or "disconnect" the socket in error // paths. -func (e *connectionedEndpoint) Close() { +func (e *connectionedEndpoint) Close(ctx context.Context) { e.Lock() var c ConnectedEndpoint var r Receiver @@ -234,7 +233,7 @@ func (e *connectionedEndpoint) Close() { case e.Listening(): close(e.acceptedChan) for n := range e.acceptedChan { - n.Close() + n.Close(ctx) } e.acceptedChan = nil e.path = "" @@ -242,18 +241,18 @@ func (e *connectionedEndpoint) Close() { e.Unlock() if c != nil { c.CloseNotify() - c.Release() + c.Release(ctx) } if r != nil { r.CloseNotify() - r.Release() + r.Release(ctx) } } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce ConnectingEndpoint, returnConnect func(Receiver, ConnectedEndpoint)) *syserr.Error { if ce.Type() != e.stype { - return syserr.ErrConnectionRefused + return syserr.ErrWrongProtocolForSocket } // Check if ce is e to avoid a deadlock. @@ -341,7 +340,7 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn return nil default: // Busy; return ECONNREFUSED per spec. - ne.Close() + ne.Close(ctx) e.Unlock() ce.Unlock() return syserr.ErrConnectionRefused @@ -477,6 +476,9 @@ func (e *connectionedEndpoint) Readiness(mask waiter.EventMask) waiter.EventMask // State implements socket.Socket.State. func (e *connectionedEndpoint) State() uint32 { + e.Lock() + defer e.Unlock() + if e.Connected() { return linux.SS_CONNECTED } diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go index 0322dec0b..70ee8f9b8 100644 --- a/pkg/sentry/socket/unix/transport/connectionless.go +++ b/pkg/sentry/socket/unix/transport/connectionless.go @@ -16,7 +16,7 @@ package transport import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/waiter" @@ -54,10 +54,10 @@ func (e *connectionlessEndpoint) isBound() bool { // Close puts the endpoint in a closed state and frees all resources associated // with it. -func (e *connectionlessEndpoint) Close() { +func (e *connectionlessEndpoint) Close(ctx context.Context) { e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) e.connected = nil } @@ -71,7 +71,7 @@ func (e *connectionlessEndpoint) Close() { e.Unlock() r.CloseNotify() - r.Release() + r.Release(ctx) } // BidirectionalConnect implements BoundEndpoint.BidirectionalConnect. @@ -108,10 +108,10 @@ func (e *connectionlessEndpoint) SendMsg(ctx context.Context, data [][]byte, c C if err != nil { return 0, syserr.ErrInvalidEndpointState } - defer connected.Release() + defer connected.Release(ctx) e.Lock() - n, notify, err := connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -135,7 +135,7 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi e.Lock() if e.connected != nil { - e.connected.Release() + e.connected.Release(ctx) } e.connected = connected e.Unlock() diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go index e27b1c714..ef6043e19 100644 --- a/pkg/sentry/socket/unix/transport/queue.go +++ b/pkg/sentry/socket/unix/transport/queue.go @@ -15,10 +15,12 @@ package transport import ( - "sync" - + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/waiter" ) @@ -56,10 +58,10 @@ func (q *queue) Close() { // Both the read and write queues must be notified after resetting: // q.ReaderQueue.Notify(waiter.EventIn) // q.WriterQueue.Notify(waiter.EventOut) -func (q *queue) Reset() { +func (q *queue) Reset(ctx context.Context) { q.mu.Lock() for cur := q.dataList.Front(); cur != nil; cur = cur.Next() { - cur.Release() + cur.Release(ctx) } q.dataList.Reset() q.used = 0 @@ -67,8 +69,8 @@ func (q *queue) Reset() { } // DecRef implements RefCounter.DecRef with destructor q.Reset. -func (q *queue) DecRef() { - q.DecRefWithDestructor(q.Reset) +func (q *queue) DecRef(ctx context.Context) { + q.DecRefWithDestructor(ctx, q.Reset) // We don't need to notify after resetting because no one cares about // this queue after all references have been dropped. } @@ -101,12 +103,16 @@ func (q *queue) IsWritable() bool { // Enqueue adds an entry to the data queue if room is available. // +// If discardEmpty is true and there are zero bytes of data, the packet is +// dropped. +// // If truncate is true, Enqueue may truncate the message before enqueuing it. -// Otherwise, the entire message must fit. If n < e.Length(), err indicates why. +// Otherwise, the entire message must fit. If l is less than the size of data, +// err indicates why. // // If notify is true, ReaderQueue.Notify must be called: // q.ReaderQueue.Notify(waiter.EventIn) -func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *syserr.Error) { +func (q *queue) Enqueue(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress, discardEmpty bool, truncate bool) (l int64, notify bool, err *syserr.Error) { q.mu.Lock() if q.closed { @@ -114,9 +120,16 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s return 0, false, syserr.ErrClosedForSend } - free := q.limit - q.used + for _, d := range data { + l += int64(len(d)) + } + if discardEmpty && l == 0 { + q.mu.Unlock() + c.Release(ctx) + return 0, false, nil + } - l = e.Length() + free := q.limit - q.used if l > free && truncate { if free == 0 { @@ -125,8 +138,7 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s return 0, false, syserr.ErrWouldBlock } - e.Truncate(free) - l = e.Length() + l = free err = syserr.ErrWouldBlock } @@ -137,14 +149,26 @@ func (q *queue) Enqueue(e *message, truncate bool) (l int64, notify bool, err *s } if l > free { - // Message can't fit right now. + // Message can't fit right now, and could not be truncated. q.mu.Unlock() return 0, false, syserr.ErrWouldBlock } + // Aggregate l bytes of data. This will truncate the data if l is less than + // the total bytes held in data. + v := make([]byte, l) + for i, b := 0, v; i < len(data) && len(b) > 0; i++ { + n := copy(b, data[i]) + b = b[n:] + } + notify = q.dataList.Front() == nil q.used += l - q.dataList.PushBack(e) + q.dataList.PushBack(&message{ + Data: buffer.View(v), + Control: c, + Address: from, + }) q.mu.Unlock() diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index 529a7a7a9..475d7177e 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -16,11 +16,12 @@ package transport import ( - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/context" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/log" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -36,7 +37,7 @@ type RightsControlMessage interface { Clone() RightsControlMessage // Release releases any resources owned by the RightsControlMessage. - Release() + Release(ctx context.Context) } // A CredentialsControlMessage is a control message containing Unix credentials. @@ -73,9 +74,9 @@ func (c *ControlMessages) Clone() ControlMessages { } // Release releases both the credentials and the rights. -func (c *ControlMessages) Release() { +func (c *ControlMessages) Release(ctx context.Context) { if c.Rights != nil { - c.Rights.Release() + c.Rights.Release(ctx) } *c = ControlMessages{} } @@ -89,7 +90,7 @@ type Endpoint interface { // Close puts the endpoint in a closed state and frees all resources // associated with it. - Close() + Close(ctx context.Context) // RecvMsg reads data and a control message from the endpoint. This method // does not block if there is no data pending. @@ -175,17 +176,25 @@ type Endpoint interface { // types. SetSockOpt(opt interface{}) *tcpip.Error + // SetSockOptBool sets a socket option for simple cases when a value has + // the int type. + SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error + // SetSockOptInt sets a socket option for simple cases when a value has // the int type. - SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error + SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error // GetSockOpt gets a socket option. opt should be a pointer to one of the // tcpip.*Option types. GetSockOpt(opt interface{}) *tcpip.Error + // GetSockOptBool gets a socket option for simple cases when a return + // value has the int type. + GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) + // GetSockOptInt gets a socket option for simple cases when a return // value has the int type. - GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) + GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) // State returns the current state of the socket, as represented by Linux in // procfs. @@ -243,7 +252,7 @@ type BoundEndpoint interface { // Release releases any resources held by the BoundEndpoint. It must be // called before dropping all references to a BoundEndpoint returned by a // function. - Release() + Release(ctx context.Context) } // message represents a message passed over a Unix domain socket. @@ -272,8 +281,8 @@ func (m *message) Length() int64 { } // Release releases any resources held by the message. -func (m *message) Release() { - m.Control.Release() +func (m *message) Release(ctx context.Context) { + m.Control.Release(ctx) } // Peek returns a copy of the message. @@ -295,7 +304,7 @@ type Receiver interface { // See Endpoint.RecvMsg for documentation on shared arguments. // // notify indicates if RecvNotify should be called. - Recv(data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) + Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (recvLen, msgLen int64, cm ControlMessages, CMTruncated bool, source tcpip.FullAddress, notify bool, err *syserr.Error) // RecvNotify notifies the Receiver of a successful Recv. This must not be // called while holding any endpoint locks. @@ -324,7 +333,7 @@ type Receiver interface { // Release releases any resources owned by the Receiver. It should be // called before droping all references to a Receiver. - Release() + Release(ctx context.Context) } // queueReceiver implements Receiver for datagram sockets. @@ -335,7 +344,7 @@ type queueReceiver struct { } // Recv implements Receiver.Recv. -func (q *queueReceiver) Recv(data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *queueReceiver) Recv(ctx context.Context, data [][]byte, creds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { var m *message var notify bool var err *syserr.Error @@ -389,8 +398,8 @@ func (q *queueReceiver) RecvMaxQueueSize() int64 { } // Release implements Receiver.Release. -func (q *queueReceiver) Release() { - q.readQueue.DecRef() +func (q *queueReceiver) Release(ctx context.Context) { + q.readQueue.DecRef(ctx) } // streamQueueReceiver implements Receiver for stream sockets. @@ -447,7 +456,7 @@ func (q *streamQueueReceiver) RecvMaxQueueSize() int64 { } // Recv implements Receiver.Recv. -func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { +func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds bool, numRights int, peek bool) (int64, int64, ControlMessages, bool, tcpip.FullAddress, bool, *syserr.Error) { q.mu.Lock() defer q.mu.Unlock() @@ -493,7 +502,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, var cmTruncated bool if c.Rights != nil && numRights == 0 { - c.Rights.Release() + c.Rights.Release(ctx) c.Rights = nil cmTruncated = true } @@ -548,7 +557,7 @@ func (q *streamQueueReceiver) Recv(data [][]byte, wantCreds bool, numRights int, // Consume rights. if numRights == 0 { cmTruncated = true - q.control.Rights.Release() + q.control.Rights.Release(ctx) } else { c.Rights = q.control.Rights haveRights = true @@ -573,7 +582,7 @@ type ConnectedEndpoint interface { // // syserr.ErrWouldBlock can be returned along with a partial write if // the caller should block to send the rest of the data. - Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) + Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (n int64, notify bool, err *syserr.Error) // SendNotify notifies the ConnectedEndpoint of a successful Send. This // must not be called while holding any endpoint locks. @@ -607,7 +616,7 @@ type ConnectedEndpoint interface { // Release releases any resources owned by the ConnectedEndpoint. It should // be called before droping all references to a ConnectedEndpoint. - Release() + Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to // the peer socket. @@ -645,35 +654,22 @@ func (e *connectedEndpoint) GetLocalAddress() (tcpip.FullAddress, *tcpip.Error) } // Send implements ConnectedEndpoint.Send. -func (e *connectedEndpoint) Send(data [][]byte, controlMessages ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { - var l int64 - for _, d := range data { - l += int64(len(d)) - } - +func (e *connectedEndpoint) Send(ctx context.Context, data [][]byte, c ControlMessages, from tcpip.FullAddress) (int64, bool, *syserr.Error) { + discardEmpty := false truncate := false if e.endpoint.Type() == linux.SOCK_STREAM { - // Since stream sockets don't preserve message boundaries, we - // can write only as much of the message as fits in the queue. - truncate = true - // Discard empty stream packets. Since stream sockets don't // preserve message boundaries, sending zero bytes is a no-op. // In Linux, the receiver actually uses a zero-length receive // as an indication that the stream was closed. - if l == 0 { - controlMessages.Release() - return 0, false, nil - } - } + discardEmpty = true - v := make([]byte, 0, l) - for _, d := range data { - v = append(v, d...) + // Since stream sockets don't preserve message boundaries, we + // can write only as much of the message as fits in the queue. + truncate = true } - l, notify, err := e.writeQueue.Enqueue(&message{Data: buffer.View(v), Control: controlMessages, Address: from}, truncate) - return int64(l), notify, err + return e.writeQueue.Enqueue(ctx, data, c, from, discardEmpty, truncate) } // SendNotify implements ConnectedEndpoint.SendNotify. @@ -711,8 +707,8 @@ func (e *connectedEndpoint) SendMaxQueueSize() int64 { } // Release implements ConnectedEndpoint.Release. -func (e *connectedEndpoint) Release() { - e.writeQueue.DecRef() +func (e *connectedEndpoint) Release(ctx context.Context) { + e.writeQueue.DecRef(ctx) } // CloseUnread implements ConnectedEndpoint.CloseUnread. @@ -802,7 +798,7 @@ func (e *baseEndpoint) RecvMsg(ctx context.Context, data [][]byte, creds bool, n return 0, 0, ControlMessages{}, false, syserr.ErrNotConnected } - recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(data, creds, numRights, peek) + recvLen, msgLen, cms, cmt, a, notify, err := e.receiver.Recv(ctx, data, creds, numRights, peek) e.Unlock() if err != nil { return 0, 0, ControlMessages{}, false, err @@ -831,7 +827,7 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess return 0, syserr.ErrAlreadyConnected } - n, notify, err := e.connected.Send(data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) + n, notify, err := e.connected.Send(ctx, data, c, tcpip.FullAddress{Addr: tcpip.Address(e.path)}) e.Unlock() if notify { @@ -843,19 +839,46 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess // SetSockOpt sets a socket option. Currently not supported. func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error { - switch v := opt.(type) { + return nil +} + +func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error { + switch opt { + case tcpip.BroadcastOption: case tcpip.PasscredOption: - e.setPasscred(v != 0) - return nil + e.setPasscred(v) + case tcpip.ReuseAddressOption: + default: + log.Warningf("Unsupported socket option: %d", opt) } return nil } -func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOpt, v int) *tcpip.Error { +func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { + switch opt { + case tcpip.SendBufferSizeOption: + case tcpip.ReceiveBufferSizeOption: + default: + log.Warningf("Unsupported socket option: %d", opt) + } return nil } -func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { +func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { + switch opt { + case tcpip.KeepaliveEnabledOption: + return false, nil + + case tcpip.PasscredOption: + return e.Passcred(), nil + + default: + log.Warningf("Unsupported socket option: %d", opt) + return false, tcpip.ErrUnknownProtocolOption + } +} + +func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { switch opt { case tcpip.ReceiveQueueSizeOption: v := 0 @@ -911,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOpt) (int, *tcpip.Error) { return int(v), nil default: + log.Warningf("Unsupported socket option: %d", opt) return -1, tcpip.ErrUnknownProtocolOption } } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch o := opt.(type) { + switch opt.(type) { case tcpip.ErrorOption: return nil - case *tcpip.PasscredOption: - if e.Passcred() { - *o = tcpip.PasscredOption(1) - } else { - *o = tcpip.PasscredOption(0) - } - return nil - - case *tcpip.KeepaliveEnabledOption: - *o = 0 - return nil - default: + log.Warningf("Unsupported socket option: %T", opt) return tcpip.ErrUnknownProtocolOption } } @@ -988,6 +1001,6 @@ func (e *baseEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error) { } // Release implements BoundEndpoint.Release. -func (*baseEndpoint) Release() { +func (*baseEndpoint) Release(context.Context) { // Binding a baseEndpoint doesn't take a reference. } diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index 1aaae8487..b7e8e4325 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -22,9 +22,9 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" - "gvisor.dev/gvisor/pkg/sentry/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -33,11 +33,13 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/control" "gvisor.dev/gvisor/pkg/sentry/socket/netstack" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" - "gvisor.dev/gvisor/pkg/sentry/usermem" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserr" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" ) // SocketOperations is a Unix socket. It is similar to a netstack socket, @@ -52,17 +54,14 @@ type SocketOperations struct { fsutil.FileNoSplice `state:"nosave"` fsutil.FileNoopFlush `state:"nosave"` fsutil.FileUseInodeUnstableAttr `state:"nosave"` - refs.AtomicRefCount - socket.SendReceiveTimeout - ep transport.Endpoint - stype linux.SockType + socketOpsCommon } // New creates a new unix socket. func New(ctx context.Context, endpoint transport.Endpoint, stype linux.SockType) *fs.File { dirent := socket.NewDirent(ctx, unixSocketDevice) - defer dirent.DecRef() + defer dirent.DecRef(ctx) return NewWithDirent(ctx, dirent, endpoint, stype, fs.FileFlags{Read: true, Write: true, NonSeekable: true}) } @@ -75,29 +74,51 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty } s := SocketOperations{ - ep: ep, - stype: stype, + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, } - s.EnableLeakCheck("unix.SocketOperations") + s.EnableLeakCheck() return fs.NewFile(ctx, d, flags, &s) } +// socketOpsCommon contains the socket operations common to VFS1 and VFS2. +// +// +stateify savable +type socketOpsCommon struct { + socketOpsCommonRefs + socket.SendReceiveTimeout + + ep transport.Endpoint + stype linux.SockType + + // abstractName and abstractNamespace indicate the name and namespace of the + // socket if it is bound to an abstract socket namespace. Once the socket is + // bound, they cannot be modified. + abstractName string + abstractNamespace *kernel.AbstractSocketNamespace +} + // DecRef implements RefCounter.DecRef. -func (s *SocketOperations) DecRef() { - s.DecRefWithDestructor(func() { - s.ep.Close() +func (s *socketOpsCommon) DecRef(ctx context.Context) { + s.socketOpsCommonRefs.DecRef(func() { + s.ep.Close(ctx) + if s.abstractNamespace != nil { + s.abstractNamespace.Remove(s.abstractName, s) + } }) } // Release implemements fs.FileOperations.Release. -func (s *SocketOperations) Release() { +func (s *socketOpsCommon) Release(ctx context.Context) { // Release only decrements a reference on s because s may be referenced in // the abstract socket namespace. - s.DecRef() + s.DecRef(ctx) } -func (s *SocketOperations) isPacket() bool { +func (s *socketOpsCommon) isPacket() bool { switch s.stype { case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET: return true @@ -110,16 +131,22 @@ func (s *SocketOperations) isPacket() bool { } // Endpoint extracts the transport.Endpoint. -func (s *SocketOperations) Endpoint() transport.Endpoint { +func (s *socketOpsCommon) Endpoint() transport.Endpoint { return s.ep } // extractPath extracts and validates the address. func extractPath(sockaddr []byte) (string, *syserr.Error) { - addr, _, err := netstack.AddressAndFamily(linux.AF_UNIX, sockaddr, true /* strict */) + addr, family, err := netstack.AddressAndFamily(sockaddr) if err != nil { + if err == syserr.ErrAddressFamilyNotSupported { + err = syserr.ErrInvalidArgument + } return "", err } + if family != linux.AF_UNIX { + return "", syserr.ErrInvalidArgument + } // The address is trimmed by GetAddress. p := string(addr.Addr) @@ -137,7 +164,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) { // GetPeerName implements the linux syscall getpeername(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetRemoteAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -149,7 +176,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, // GetSockName implements the linux syscall getsockname(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { +func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) { addr, err := s.ep.GetLocalAddress() if err != nil { return nil, 0, syserr.TranslateNetstackError(err) @@ -166,13 +193,13 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO, // GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) { +func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) } // Listen implements the linux syscall listen(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error { +func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error { return s.ep.Listen(backlog) } @@ -215,7 +242,7 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, } ns := New(t, ep, s.stype) - defer ns.DecRef() + defer ns.DecRef(t) if flags&linux.SOCK_NONBLOCK != 0 { flags := ns.Flags() @@ -265,17 +292,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { if t.IsNetworkNamespaced() { return syserr.ErrInvalidEndpointState } - if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil { + asn := t.AbstractSockets() + name := p[1:] + if err := asn.Bind(t, name, bep, s); err != nil { // syserr.ErrPortInUse corresponds to EADDRINUSE. return syserr.ErrPortInUse } + s.abstractName = name + s.abstractNamespace = asn } else { // The parent and name. var d *fs.Dirent var name string cwd := t.FSContext().WorkingDirectory() - defer cwd.DecRef() + defer cwd.DecRef(t) // Is there no slash at all? if !strings.Contains(p, "/") { @@ -283,7 +314,7 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { name = p } else { root := t.FSContext().RootDirectory() - defer root.DecRef() + defer root.DecRef(t) // Find the last path component, we know that something follows // that final slash, otherwise extractPath() would have failed. lastSlash := strings.LastIndex(p, "/") @@ -299,16 +330,21 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { // No path available. return syserr.ErrNoSuchFile } - defer d.DecRef() + defer d.DecRef(t) name = p[lastSlash+1:] } // Create the socket. + // + // Note that the file permissions here are not set correctly (see + // gvisor.dev/issue/2324). There is no convenient way to get permissions + // on the socket referred to by s, so we will leave this discrepancy + // unresolved until VFS2 replaces this code. childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}}) if err != nil { return syserr.ErrPortInUse } - childDir.DecRef() + childDir.DecRef(t) } return nil @@ -339,41 +375,76 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint, return ep, nil } + if kernel.VFS2Enabled { + p := fspath.Parse(path) + root := t.FSContext().RootDirectoryVFS2() + start := root + relPath := !p.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: p, + FollowFinalSymlink: true, + } + ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop, &vfs.BoundEndpointOptions{path}) + root.DecRef(t) + if relPath { + start.DecRef(t) + } + if e != nil { + return nil, syserr.FromError(e) + } + return ep, nil + } + // Find the node in the filesystem. root := t.FSContext().RootDirectory() cwd := t.FSContext().WorkingDirectory() remainingTraversals := uint(fs.DefaultTraversalLimit) d, e := t.MountNamespace().FindInode(t, root, cwd, path, &remainingTraversals) - cwd.DecRef() - root.DecRef() + cwd.DecRef(t) + root.DecRef(t) if e != nil { return nil, syserr.FromError(e) } // Extract the endpoint if one is there. ep := d.Inode.BoundEndpoint(path) - d.DecRef() + d.DecRef(t) if ep == nil { // No socket! return nil, syserr.ErrConnectionRefused } - return ep, nil } // Connect implements the linux syscall connect(2) for unix sockets. -func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { +func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error { ep, err := extractEndpoint(t, sockaddr) if err != nil { return err } - defer ep.Release() + defer ep.Release(t) // Connect the server endpoint. - return s.ep.Connect(t, ep) + err = s.ep.Connect(t, ep) + + if err == syserr.ErrWrongProtocolForSocket { + // Linux for abstract sockets returns ErrConnectionRefused + // instead of ErrWrongProtocolForSocket. + path, _ := extractPath(sockaddr) + if len(path) > 0 && path[0] == 0 { + err = syserr.ErrConnectionRefused + } + } + + return err } -// Writev implements fs.FileOperations.Write. +// Write implements fs.FileOperations.Write. func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) { t := kernel.TaskFromContext(ctx) ctrl := control.New(t, s.ep, nil) @@ -393,7 +464,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO // SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by // a transport.Endpoint. -func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { +func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) { w := EndpointWriter{ Ctx: t, Endpoint: s.ep, @@ -401,15 +472,25 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] To: nil, } if len(to) > 0 { - ep, err := extractEndpoint(t, to) - if err != nil { - return 0, err - } - defer ep.Release() - w.To = ep + switch s.stype { + case linux.SOCK_SEQPACKET: + to = nil + case linux.SOCK_STREAM: + if s.State() == linux.SS_CONNECTED { + return 0, syserr.ErrAlreadyConnected + } + return 0, syserr.ErrNotSupported + default: + ep, err := extractEndpoint(t, to) + if err != nil { + return 0, err + } + defer ep.Release(t) + w.To = ep - if ep.Passcred() && w.Control.Credentials == nil { - w.Control.Credentials = control.MakeCreds(t) + if ep.Passcred() && w.Control.Credentials == nil { + w.Control.Credentials = control.MakeCreds(t) + } } } @@ -447,27 +528,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to [] } // Passcred implements transport.Credentialer.Passcred. -func (s *SocketOperations) Passcred() bool { +func (s *socketOpsCommon) Passcred() bool { return s.ep.Passcred() } // ConnectedPasscred implements transport.Credentialer.ConnectedPasscred. -func (s *SocketOperations) ConnectedPasscred() bool { +func (s *socketOpsCommon) ConnectedPasscred() bool { return s.ep.ConnectedPasscred() } // Readiness implements waiter.Waitable.Readiness. -func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask { +func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask { return s.ep.Readiness(mask) } // EventRegister implements waiter.Waitable.EventRegister. -func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) { +func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) { s.ep.EventRegister(e, mask) } // EventUnregister implements waiter.Waitable.EventUnregister. -func (s *SocketOperations) EventUnregister(e *waiter.Entry) { +func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) { s.ep.EventUnregister(e) } @@ -479,7 +560,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa // Shutdown implements the linux syscall shutdown(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error { +func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error { f, err := netstack.ConvertShutdown(how) if err != nil { return err @@ -505,7 +586,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS // RecvMsg implements the linux syscall recvmsg(2) for sockets backed by // a transport.Endpoint. -func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { +func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) { trunc := flags&linux.MSG_TRUNC != 0 peek := flags&linux.MSG_PEEK != 0 dontWait := flags&linux.MSG_DONTWAIT != 0 @@ -541,8 +622,27 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags if senderRequested { r.From = &tcpip.FullAddress{} } + + doRead := func() (int64, error) { + return dst.CopyOutFrom(t, &r) + } + + // If MSG_TRUNC is set with a zero byte destination then we still need + // to read the message and discard it, or in the case where MSG_PEEK is + // set, leave it be. In both cases the full message length must be + // returned. + if trunc && dst.Addrs.NumBytes() == 0 { + doRead = func() (int64, error) { + err := r.Truncate() + // Always return zero for bytes read since the destination size is + // zero. + return 0, err + } + + } + var total int64 - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock || dontWait { + if n, err := doRead(); err != syserror.ErrWouldBlock || dontWait { var from linux.SockAddr var fromLen uint32 if r.From != nil && len([]byte(r.From.Addr)) != 0 { @@ -577,7 +677,7 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags defer s.EventUnregister(&e) for { - if n, err := dst.CopyOutFrom(t, &r); err != syserror.ErrWouldBlock { + if n, err := doRead(); err != syserror.ErrWouldBlock { var from linux.SockAddr var fromLen uint32 if r.From != nil { @@ -623,12 +723,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags } // State implements socket.Socket.State. -func (s *SocketOperations) State() uint32 { +func (s *socketOpsCommon) State() uint32 { return s.ep.State() } // Type implements socket.Socket.Type. -func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) { +func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) { // Unix domain sockets always have a protocol of 0. return linux.AF_UNIX, s.stype, 0 } @@ -681,4 +781,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F func init() { socket.RegisterProvider(linux.AF_UNIX, &provider{}) + socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{}) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go new file mode 100644 index 000000000..d066ef8ab --- /dev/null +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -0,0 +1,376 @@ +// Copyright 2018 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package unix + +import ( + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/sentry/arch" + fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/sentry/socket" + "gvisor.dev/gvisor/pkg/sentry/socket/control" + "gvisor.dev/gvisor/pkg/sentry/socket/netstack" + "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/usermem" + "gvisor.dev/gvisor/pkg/waiter" + "gvisor.dev/gvisor/tools/go_marshal/marshal" +) + +// SocketVFS2 implements socket.SocketVFS2 (and by extension, +// vfs.FileDescriptionImpl) for Unix sockets. +type SocketVFS2 struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.DentryMetadataFileDescriptionImpl + vfs.LockFD + + socketOpsCommon +} + +var _ = socket.SocketVFS2(&SocketVFS2{}) + +// NewSockfsFile creates a new socket file in the global sockfs mount and +// returns a corresponding file description. +func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { + mnt := t.Kernel().SocketMount() + d := sockfs.NewDentry(t.Credentials(), mnt) + + fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) + if err != nil { + return nil, syserr.FromError(err) + } + return fd, nil +} + +// NewFileDescription creates and returns a socket file description +// corresponding to the given mount and dentry. +func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint32, mnt *vfs.Mount, d *vfs.Dentry, locks *vfs.FileLocks) (*vfs.FileDescription, error) { + // You can create AF_UNIX, SOCK_RAW sockets. They're the same as + // SOCK_DGRAM and don't require CAP_NET_RAW. + if stype == linux.SOCK_RAW { + stype = linux.SOCK_DGRAM + } + + sock := &SocketVFS2{ + socketOpsCommon: socketOpsCommon{ + ep: ep, + stype: stype, + }, + } + sock.LockFD.Init(locks) + vfsfd := &sock.vfsfd + if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ + DenyPRead: true, + DenyPWrite: true, + UseDentryMetadata: true, + }); err != nil { + return nil, err + } + return vfsfd, nil +} + +// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { + return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen) +} + +// blockingAccept implements a blocking version of accept(2), that is, if no +// connections are ready to be accept, it will block until one becomes ready. +func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) { + // Register for notifications. + e, ch := waiter.NewChannelEntry(nil) + s.socketOpsCommon.EventRegister(&e, waiter.EventIn) + defer s.socketOpsCommon.EventUnregister(&e) + + // Try to accept the connection; if it fails, then wait until we get a + // notification. + for { + if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock { + return ep, err + } + + if err := t.Block(ch); err != nil { + return nil, syserr.FromError(err) + } + } +} + +// Accept implements the linux syscall accept(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) { + // Issue the accept request to get the new endpoint. + ep, err := s.ep.Accept() + if err != nil { + if err != syserr.ErrWouldBlock || !blocking { + return 0, nil, 0, err + } + + var err *syserr.Error + ep, err = s.blockingAccept(t) + if err != nil { + return 0, nil, 0, err + } + } + + ns, err := NewSockfsFile(t, ep, s.stype) + if err != nil { + return 0, nil, 0, err + } + defer ns.DecRef(t) + + if flags&linux.SOCK_NONBLOCK != 0 { + ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK) + } + + var addr linux.SockAddr + var addrLen uint32 + if peerRequested { + // Get address of the peer. + var err *syserr.Error + addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t) + if err != nil { + return 0, nil, 0, err + } + } + + fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{ + CloseOnExec: flags&linux.SOCK_CLOEXEC != 0, + }) + if e != nil { + return 0, nil, 0, syserr.FromError(e) + } + + t.Kernel().RecordSocketVFS2(ns) + return fd, addr, addrLen, nil +} + +// Bind implements the linux syscall bind(2) for unix sockets. +func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error { + p, e := extractPath(sockaddr) + if e != nil { + return e + } + + bep, ok := s.ep.(transport.BoundEndpoint) + if !ok { + // This socket can't be bound. + return syserr.ErrInvalidArgument + } + + return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error { + // Is it abstract? + if p[0] == 0 { + if t.IsNetworkNamespaced() { + return syserr.ErrInvalidEndpointState + } + asn := t.AbstractSockets() + name := p[1:] + if err := asn.Bind(t, name, bep, s); err != nil { + // syserr.ErrPortInUse corresponds to EADDRINUSE. + return syserr.ErrPortInUse + } + s.abstractName = name + s.abstractNamespace = asn + } else { + path := fspath.Parse(p) + root := t.FSContext().RootDirectoryVFS2() + defer root.DecRef(t) + start := root + relPath := !path.Absolute + if relPath { + start = t.FSContext().WorkingDirectoryVFS2() + defer start.DecRef(t) + } + pop := vfs.PathOperation{ + Root: root, + Start: start, + Path: path, + } + stat, err := s.vfsfd.Stat(t, vfs.StatOptions{Mask: linux.STATX_MODE}) + if err != nil { + return syserr.FromError(err) + } + err = t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{ + // File permissions correspond to net/unix/af_unix.c:unix_bind. + Mode: linux.FileMode(linux.S_IFSOCK | uint(stat.Mode)&^t.FSContext().Umask()), + Endpoint: bep, + }) + if err == syserror.EEXIST { + return syserr.ErrAddressInUse + } + return syserr.FromError(err) + } + + return nil + }) +} + +// Ioctl implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) { + return netstack.Ioctl(ctx, s.ep, uio, args) +} + +// PRead implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Read implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + if dst.NumBytes() == 0 { + return 0, nil + } + return dst.CopyOutFrom(ctx, &EndpointReader{ + Ctx: ctx, + Endpoint: s.ep, + NumRights: 0, + Peek: false, + From: nil, + }) +} + +// PWrite implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.ESPIPE +} + +// Write implements vfs.FileDescriptionImpl. +func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + // All flags other than RWF_NOWAIT should be ignored. + // TODO(gvisor.dev/issue/2601): Support RWF_NOWAIT. + if opts.Flags != 0 { + return 0, syserror.EOPNOTSUPP + } + + t := kernel.TaskFromContext(ctx) + ctrl := control.New(t, s.ep, nil) + + if src.NumBytes() == 0 { + nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil) + return int64(nInt), err.ToError() + } + + return src.CopyInTo(ctx, &EndpointWriter{ + Ctx: ctx, + Endpoint: s.ep, + Control: ctrl, + To: nil, + }) +} + +// Readiness implements waiter.Waitable.Readiness. +func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask { + return s.socketOpsCommon.Readiness(mask) +} + +// EventRegister implements waiter.Waitable.EventRegister. +func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) { + s.socketOpsCommon.EventRegister(e, mask) +} + +// EventUnregister implements waiter.Waitable.EventUnregister. +func (s *SocketVFS2) EventUnregister(e *waiter.Entry) { + s.socketOpsCommon.EventUnregister(e) +} + +// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by +// a transport.Endpoint. +func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error { + return netstack.SetSockOpt(t, s, s.ep, level, name, optVal) +} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (s *SocketVFS2) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return s.Locks().LockPOSIX(ctx, &s.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (s *SocketVFS2) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return s.Locks().UnlockPOSIX(ctx, &s.vfsfd, uid, start, length, whence) +} + +// providerVFS2 is a unix domain socket provider for VFS2. +type providerVFS2 struct{} + +func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, syserr.ErrProtocolNotSupported + } + + // Create the endpoint and socket. + var ep transport.Endpoint + switch stype { + case linux.SOCK_DGRAM, linux.SOCK_RAW: + ep = transport.NewConnectionless(t) + case linux.SOCK_SEQPACKET, linux.SOCK_STREAM: + ep = transport.NewConnectioned(t, stype, t.Kernel()) + default: + return nil, syserr.ErrInvalidArgument + } + + f, err := NewSockfsFile(t, ep, stype) + if err != nil { + ep.Close(t) + return nil, err + } + return f, nil +} + +// Pair creates a new pair of AF_UNIX connected sockets. +func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) { + // Check arguments. + if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ { + return nil, nil, syserr.ErrProtocolNotSupported + } + + switch stype { + case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW: + // Ok + default: + return nil, nil, syserr.ErrInvalidArgument + } + + // Create the endpoints and sockets. + ep1, ep2 := transport.NewPair(t, stype, t.Kernel()) + s1, err := NewSockfsFile(t, ep1, stype) + if err != nil { + ep1.Close(t) + ep2.Close(t) + return nil, nil, err + } + s2, err := NewSockfsFile(t, ep2, stype) + if err != nil { + s1.DecRef(t) + ep2.Close(t) + return nil, nil, err + } + + return s1, s2, nil +} |