summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/control/BUILD1
-rw-r--r--pkg/sentry/socket/control/control.go71
-rw-r--r--pkg/sentry/socket/hostinet/socket.go11
-rw-r--r--pkg/sentry/socket/netfilter/BUILD3
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go21
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go332
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go128
-rw-r--r--pkg/sentry/socket/netfilter/targets.go34
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go11
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go13
-rw-r--r--pkg/sentry/socket/netlink/message.go15
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go293
-rw-r--r--pkg/sentry/socket/netstack/provider.go16
-rw-r--r--pkg/sentry/socket/netstack/stack.go76
-rw-r--r--pkg/sentry/socket/socket.go89
-rw-r--r--pkg/sentry/socket/unix/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go48
-rw-r--r--pkg/sentry/socket/unix/unix.go89
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go348
22 files changed, 1274 insertions, 332 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 611fa22c3..c40c6d673 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/control/BUILD b/pkg/sentry/socket/control/BUILD
index 79e16d6e8..4d42d29cb 100644
--- a/pkg/sentry/socket/control/BUILD
+++ b/pkg/sentry/socket/control/BUILD
@@ -19,6 +19,7 @@ go_library(
"//pkg/sentry/socket",
"//pkg/sentry/socket/unix/transport",
"//pkg/syserror",
+ "//pkg/tcpip",
"//pkg/usermem",
],
)
diff --git a/pkg/sentry/socket/control/control.go b/pkg/sentry/socket/control/control.go
index 00265f15b..8834a1e1a 100644
--- a/pkg/sentry/socket/control/control.go
+++ b/pkg/sentry/socket/control/control.go
@@ -26,6 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
"gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -189,7 +190,7 @@ func putUint32(buf []byte, n uint32) []byte {
// putCmsg writes a control message header and as much data as will fit into
// the unused capacity of a buffer.
func putCmsg(buf []byte, flags int, msgType uint32, align uint, data []int32) ([]byte, int) {
- space := AlignDown(cap(buf)-len(buf), 4)
+ space := binary.AlignDown(cap(buf)-len(buf), 4)
// We can't write to space that doesn't exist, so if we are going to align
// the available space, we must align down.
@@ -282,19 +283,9 @@ func PackCredentials(t *kernel.Task, creds SCMCredentials, buf []byte, flags int
return putCmsg(buf, flags, linux.SCM_CREDENTIALS, align, c)
}
-// AlignUp rounds a length up to an alignment. align must be a power of 2.
-func AlignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
-
-// AlignDown rounds a down to an alignment. align must be a power of 2.
-func AlignDown(length int, align uint) int {
- return length & ^(int(align) - 1)
-}
-
// alignSlice extends a slice's length (up to the capacity) to align it.
func alignSlice(buf []byte, align uint) []byte {
- aligned := AlignUp(len(buf), align)
+ aligned := binary.AlignUp(len(buf), align)
if aligned > cap(buf) {
// Linux allows unaligned data if there isn't room for alignment.
// Since there isn't room for alignment, there isn't room for any
@@ -338,7 +329,7 @@ func PackTOS(t *kernel.Task, tos uint8, buf []byte) []byte {
}
// PackTClass packs an IPV6_TCLASS socket control message.
-func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
+func PackTClass(t *kernel.Task, tClass uint32, buf []byte) []byte {
return putCmsgStruct(
buf,
linux.SOL_IPV6,
@@ -348,6 +339,22 @@ func PackTClass(t *kernel.Task, tClass int32, buf []byte) []byte {
)
}
+// PackIPPacketInfo packs an IP_PKTINFO socket control message.
+func PackIPPacketInfo(t *kernel.Task, packetInfo tcpip.IPPacketInfo, buf []byte) []byte {
+ var p linux.ControlMessageIPPacketInfo
+ p.NIC = int32(packetInfo.NIC)
+ copy(p.LocalAddr[:], []byte(packetInfo.LocalAddr))
+ copy(p.DestinationAddr[:], []byte(packetInfo.DestinationAddr))
+
+ return putCmsgStruct(
+ buf,
+ linux.SOL_IP,
+ linux.IP_PKTINFO,
+ t.Arch().Width(),
+ p,
+ )
+}
+
// PackControlMessages packs control messages into the given buffer.
//
// We skip control messages specific to Unix domain sockets.
@@ -372,12 +379,16 @@ func PackControlMessages(t *kernel.Task, cmsgs socket.ControlMessages, buf []byt
buf = PackTClass(t, cmsgs.IP.TClass, buf)
}
+ if cmsgs.IP.HasIPPacketInfo {
+ buf = PackIPPacketInfo(t, cmsgs.IP.PacketInfo, buf)
+ }
+
return buf
}
// cmsgSpace is equivalent to CMSG_SPACE in Linux.
func cmsgSpace(t *kernel.Task, dataLen int) int {
- return linux.SizeOfControlMessageHeader + AlignUp(dataLen, t.Arch().Width())
+ return linux.SizeOfControlMessageHeader + binary.AlignUp(dataLen, t.Arch().Width())
}
// CmsgsSpace returns the number of bytes needed to fit the control messages
@@ -404,6 +415,16 @@ func CmsgsSpace(t *kernel.Task, cmsgs socket.ControlMessages) int {
return space
}
+// NewIPPacketInfo returns the IPPacketInfo struct.
+func NewIPPacketInfo(packetInfo linux.ControlMessageIPPacketInfo) tcpip.IPPacketInfo {
+ var p tcpip.IPPacketInfo
+ p.NIC = tcpip.NICID(packetInfo.NIC)
+ copy([]byte(p.LocalAddr), packetInfo.LocalAddr[:])
+ copy([]byte(p.DestinationAddr), packetInfo.DestinationAddr[:])
+
+ return p
+}
+
// Parse parses a raw socket control message into portable objects.
func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.ControlMessages, error) {
var (
@@ -437,7 +458,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
case linux.SOL_SOCKET:
switch h.Type {
case linux.SCM_RIGHTS:
- rightsSize := AlignDown(length, linux.SizeOfControlMessageRight)
+ rightsSize := binary.AlignDown(length, linux.SizeOfControlMessageRight)
numRights := rightsSize / linux.SizeOfControlMessageRight
if len(fds)+numRights > linux.SCM_MAX_FD {
@@ -448,7 +469,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
fds = append(fds, int32(usermem.ByteOrder.Uint32(buf[j:j+linux.SizeOfControlMessageRight])))
}
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
case linux.SCM_CREDENTIALS:
if length < linux.SizeOfControlMessageCredentials {
@@ -462,7 +483,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
return socket.ControlMessages{}, err
}
cmsgs.Unix.Credentials = scmCreds
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
// Unknown message type.
@@ -476,7 +497,19 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTOS = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTOS], usermem.ByteOrder, &cmsgs.IP.TOS)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
+
+ case linux.IP_PKTINFO:
+ if length < linux.SizeOfControlMessageIPPacketInfo {
+ return socket.ControlMessages{}, syserror.EINVAL
+ }
+
+ cmsgs.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+
+ cmsgs.IP.PacketInfo = NewIPPacketInfo(packetInfo)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
@@ -489,7 +522,7 @@ func Parse(t *kernel.Task, socketOrEndpoint interface{}, buf []byte) (socket.Con
}
cmsgs.IP.HasTClass = true
binary.Unmarshal(buf[i:i+linux.SizeOfControlMessageTClass], usermem.ByteOrder, &cmsgs.IP.TClass)
- i += AlignUp(length, width)
+ i += binary.AlignUp(length, width)
default:
return socket.ControlMessages{}, syserror.EINVAL
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index de76388ac..22f78d2e2 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -289,7 +289,7 @@ func (s *socketOperations) GetSockOpt(t *kernel.Task, level int, name int, outPt
switch level {
case linux.SOL_IP:
switch name {
- case linux.IP_TOS, linux.IP_RECVTOS:
+ case linux.IP_TOS, linux.IP_RECVTOS, linux.IP_PKTINFO:
optlen = sizeofInt32
}
case linux.SOL_IPV6:
@@ -336,6 +336,8 @@ func (s *socketOperations) SetSockOpt(t *kernel.Task, level int, name int, opt [
switch name {
case linux.IP_TOS, linux.IP_RECVTOS:
optlen = sizeofInt32
+ case linux.IP_PKTINFO:
+ optlen = linux.SizeOfControlMessageIPPacketInfo
}
case linux.SOL_IPV6:
switch name {
@@ -473,7 +475,14 @@ func (s *socketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
case syscall.IP_TOS:
controlMessages.IP.HasTOS = true
binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageTOS], usermem.ByteOrder, &controlMessages.IP.TOS)
+
+ case syscall.IP_PKTINFO:
+ controlMessages.IP.HasIPPacketInfo = true
+ var packetInfo linux.ControlMessageIPPacketInfo
+ binary.Unmarshal(unixCmsg.Data[:linux.SizeOfControlMessageIPPacketInfo], usermem.ByteOrder, &packetInfo)
+ controlMessages.IP.PacketInfo = control.NewIPPacketInfo(packetInfo)
}
+
case syscall.SOL_IPV6:
switch unixCmsg.Header.Type {
case syscall.IPV6_TCLASS:
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index c91ec7494..721094bbf 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -7,6 +7,8 @@ go_library(
srcs = [
"extensions.go",
"netfilter.go",
+ "owner_matcher.go",
+ "targets.go",
"tcp_matcher.go",
"udp_matcher.go",
],
@@ -21,7 +23,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/usermem",
],
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index 22fd0ebe7..0336a32d8 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,7 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,12 +37,12 @@ type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
name() string
- // marshal converts from an iptables.Matcher to an ABI struct.
- marshal(matcher iptables.Matcher) []byte
+ // marshal converts from an stack.Matcher to an ABI struct.
+ marshal(matcher stack.Matcher) []byte
// unmarshal converts from the ABI matcher struct to an
- // iptables.Matcher.
- unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error)
+ // stack.Matcher.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
}
// matchMakers maps the name of supported matchers to the matchMaker that
@@ -58,7 +58,7 @@ func registerMatchMaker(mm matchMaker) {
matchMakers[mm.name()] = mm
}
-func marshalMatcher(matcher iptables.Matcher) []byte {
+func marshalMatcher(matcher stack.Matcher) []byte {
matchMaker, ok := matchMakers[matcher.Name()]
if !ok {
panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
@@ -72,7 +72,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
nflog("marshaling matcher %q", name)
// We have to pad this struct size to a multiple of 8 bytes.
- size := alignUp(linux.SizeOfXTEntryMatch+len(data), 8)
+ size := binary.AlignUp(linux.SizeOfXTEntryMatch+len(data), 8)
matcher := linux.KernelXTEntryMatch{
XTEntryMatch: linux.XTEntryMatch{
MatchSize: uint16(size),
@@ -86,15 +86,10 @@ func marshalEntryMatch(name string, data []byte) []byte {
return append(buf, make([]byte, size-len(buf))...)
}
-func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) {
+func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
matchMaker, ok := matchMakers[match.Name.String()]
if !ok {
return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
}
return matchMaker.unmarshal(buf, filter)
}
-
-// alignUp rounds a length up to an alignment. align must be a power of 2.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) & ^(int(align) - 1)
-}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index ea02627de..878f81fd5 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -26,7 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -35,6 +35,11 @@ import (
// shouldn't be reached - an error has occurred if we fall through to one.
const errorTargetName = "ERROR"
+// redirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const redirectTargetName = "REDIRECT"
+
// Metadata is used to verify that we are correctly serializing and
// deserializing iptables into structs consumable by the iptables tool. We save
// a metadata struct when the tables are written, and when they are read out we
@@ -50,7 +55,9 @@ type metadata struct {
// nflog logs messages related to the writing and reading of iptables.
func nflog(format string, args ...interface{}) {
- log.Infof("netfilter: "+format, args...)
+ if log.IsLogging(log.Debug) {
+ log.Debugf("netfilter: "+format, args...)
+ }
}
// GetInfo returns information about iptables.
@@ -121,19 +128,19 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
return entries, nil
}
-func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) {
- ipt := stack.IPTables()
+func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) {
+ ipt := stk.IPTables()
table, ok := ipt.Tables[tablename.String()]
if !ok {
- return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename)
+ return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
// FillDefaultIPTables sets stack's IPTables to the default tables and
// populates them with metadata.
-func FillDefaultIPTables(stack *stack.Stack) {
- ipt := iptables.DefaultTables()
+func FillDefaultIPTables(stk *stack.Stack) {
+ ipt := stack.DefaultTables()
// In order to fill in the metadata, we have to translate ipt from its
// netstack format to Linux's giant-binary-blob format.
@@ -146,14 +153,14 @@ func FillDefaultIPTables(stack *stack.Stack) {
ipt.Tables[name] = table
}
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
}
// convertNetstackToBinary converts the iptables as stored in netstack to the
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) {
+func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
@@ -226,21 +233,29 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern
return entries, meta, nil
}
-func marshalTarget(target iptables.Target) []byte {
- switch target.(type) {
- case iptables.UnconditionalAcceptTarget:
- return marshalStandardTarget(iptables.Accept)
- case iptables.UnconditionalDropTarget:
- return marshalStandardTarget(iptables.Drop)
- case iptables.ErrorTarget:
- return marshalErrorTarget()
+func marshalTarget(target stack.Target) []byte {
+ switch tg := target.(type) {
+ case stack.AcceptTarget:
+ return marshalStandardTarget(stack.RuleAccept)
+ case stack.DropTarget:
+ return marshalStandardTarget(stack.RuleDrop)
+ case stack.ErrorTarget:
+ return marshalErrorTarget(errorTargetName)
+ case stack.UserChainTarget:
+ return marshalErrorTarget(tg.Name)
+ case stack.ReturnTarget:
+ return marshalStandardTarget(stack.RuleReturn)
+ case stack.RedirectTarget:
+ return marshalRedirectTarget()
+ case JumpTarget:
+ return marshalJumpTarget(tg)
default:
panic(fmt.Errorf("unknown target of type %T", target))
}
}
-func marshalStandardTarget(verdict iptables.Verdict) []byte {
- nflog("convert to binary: marshalling standard target with size %d", linux.SizeOfXTStandardTarget)
+func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
+ nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
target := linux.XTStandardTarget{
@@ -254,60 +269,87 @@ func marshalStandardTarget(verdict iptables.Verdict) []byte {
return binary.Marshal(ret, usermem.ByteOrder, target)
}
-func marshalErrorTarget() []byte {
+func marshalErrorTarget(errorName string) []byte {
// This is an error target named error
target := linux.XTErrorTarget{
Target: linux.XTEntryTarget{
TargetSize: linux.SizeOfXTErrorTarget,
},
}
- copy(target.Name[:], errorTargetName)
+ copy(target.Name[:], errorName)
copy(target.Target.Name[:], errorTargetName)
ret := make([]byte, 0, linux.SizeOfXTErrorTarget)
return binary.Marshal(ret, usermem.ByteOrder, target)
}
+func marshalRedirectTarget() []byte {
+ // This is a redirect target named redirect
+ target := linux.XTRedirectTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTRedirectTarget,
+ },
+ }
+ copy(target.Target.Name[:], redirectTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
+func marshalJumpTarget(jt JumpTarget) []byte {
+ nflog("convert to binary: marshalling jump target")
+
+ // The target's name will be the empty string.
+ target := linux.XTStandardTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTStandardTarget,
+ },
+ // Verdict is overloaded by the ABI. When positive, it holds
+ // the jump offset from the start of the table.
+ Verdict: int32(jt.Offset),
+ }
+
+ ret := make([]byte, 0, linux.SizeOfXTStandardTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateFromStandardVerdict(verdict iptables.Verdict) int32 {
+func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
switch verdict {
- case iptables.Accept:
+ case stack.RuleAccept:
return -linux.NF_ACCEPT - 1
- case iptables.Drop:
+ case stack.RuleDrop:
return -linux.NF_DROP - 1
- case iptables.Queue:
- return -linux.NF_QUEUE - 1
- case iptables.Return:
+ case stack.RuleReturn:
return linux.NF_RETURN
- case iptables.Jump:
+ default:
// TODO(gvisor.dev/issue/170): Support Jump.
- panic("Jump isn't supported yet")
+ panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
}
- panic(fmt.Sprintf("unknown standard verdict: %d", verdict))
}
-// translateToStandardVerdict translates from the value in a
-// linux.XTStandardTarget to an iptables.Verdict.
-func translateToStandardVerdict(val int32) (iptables.Verdict, error) {
+// translateToStandardTarget translates from the value in a
+// linux.XTStandardTarget to an stack.Verdict.
+func translateToStandardTarget(val int32) (stack.Target, error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return iptables.Accept, nil
+ return stack.AcceptTarget{}, nil
case -linux.NF_DROP - 1:
- return iptables.Drop, nil
+ return stack.DropTarget{}, nil
case -linux.NF_QUEUE - 1:
- return iptables.Invalid, errors.New("unsupported iptables verdict QUEUE")
+ return nil, errors.New("unsupported iptables verdict QUEUE")
case linux.NF_RETURN:
- return iptables.Invalid, errors.New("unsupported iptables verdict RETURN")
+ return stack.ReturnTarget{}, nil
default:
- return iptables.Invalid, fmt.Errorf("unknown iptables verdict %d", val)
+ return nil, fmt.Errorf("unknown iptables verdict %d", val)
}
}
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
+func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Get the basic rules data (struct ipt_replace).
if len(optVal) < linux.SizeOfIPTReplace {
nflog("optVal has insufficient size for replace %d", len(optVal))
@@ -319,10 +361,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
// TODO(gvisor.dev/issue/170): Support other tables.
- var table iptables.Table
+ var table stack.Table
switch replace.Name.String() {
- case iptables.TablenameFilter:
- table = iptables.EmptyFilterTable()
+ case stack.TablenameFilter:
+ table = stack.EmptyFilterTable()
+ case stack.TablenameNat:
+ table = stack.EmptyNatTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -332,7 +376,8 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Convert input into a list of rules and their offsets.
var offset uint32
- var offsets []uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
nflog("set entries: processing entry at offset %d", offset)
@@ -381,23 +426,24 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return syserr.ErrInvalidArgument
}
- target, err := parseTarget(optVal[:targetSize])
+ target, err := parseTarget(filter, optVal[:targetSize])
if err != nil {
nflog("failed to parse target: %v", err)
return syserr.ErrInvalidArgument
}
optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, iptables.Rule{
+ table.Rules = append(table.Rules, stack.Rule{
Filter: filter,
Target: target,
Matchers: matchers,
})
- offsets = append(offsets, offset)
+ offsets[offset] = int(entryIdx)
offset += uint32(entry.NextOffset)
if initialOptValLen-len(optVal) != int(entry.NextOffset) {
nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return syserr.ErrInvalidArgument
}
}
@@ -406,31 +452,76 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
for hook, _ := range replace.HookEntry {
if table.ValidHooks()&(1<<hook) != 0 {
hk := hookFromLinux(hook)
- for ruleIdx, offset := range offsets {
+ for offset, ruleIdx := range offsets {
if offset == replace.HookEntry[hook] {
table.BuiltinChains[hk] = ruleIdx
}
if offset == replace.Underflow[hook] {
+ if !validUnderflow(table.Rules[ruleIdx]) {
+ nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP")
+ return syserr.ErrInvalidArgument
+ }
table.Underflows[hk] = ruleIdx
}
}
- if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset {
nflog("hook %v is unset.", hk)
return syserr.ErrInvalidArgument
}
- if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset {
nflog("underflow %v is unset.", hk)
return syserr.ErrInvalidArgument
}
}
}
+ // Add the user chains.
+ for ruleIdx, rule := range table.Rules {
+ target, ok := rule.Target.(stack.UserChainTarget)
+ if !ok {
+ continue
+ }
+
+ // We found a user chain. Before inserting it into the table,
+ // check that:
+ // - There's some other rule after it.
+ // - There are no matchers.
+ if ruleIdx == len(table.Rules)-1 {
+ nflog("user chain must have a rule or default policy")
+ return syserr.ErrInvalidArgument
+ }
+ if len(table.Rules[ruleIdx].Matchers) != 0 {
+ nflog("user chain's first node must have no matchers")
+ return syserr.ErrInvalidArgument
+ }
+ table.UserChains[target.Name] = ruleIdx + 1
+ }
+
+ // Set each jump to point to the appropriate rule. Right now they hold byte
+ // offsets.
+ for ruleIdx, rule := range table.Rules {
+ jump, ok := rule.Target.(JumpTarget)
+ if !ok {
+ continue
+ }
+
+ // Find the rule corresponding to the jump rule offset.
+ jumpTo, ok := offsets[jump.Offset]
+ if !ok {
+ nflog("failed to find a rule to jump to")
+ return syserr.ErrInvalidArgument
+ }
+ jump.RuleNum = jumpTo
+ rule.Target = jump
+ table.Rules[ruleIdx] = rule
+ }
+
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT chain right now, make sure
- // all other chains point to ACCEPT rules.
+ // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
+ // make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook != iptables.Input {
- if _, ok := table.Rules[ruleIdx].Target.(iptables.UnconditionalAcceptTarget); !ok {
+ if hook == stack.Forward || hook == stack.Postrouting {
+ if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
@@ -442,7 +533,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- ipt := stack.IPTables()
+ ipt := stk.IPTables()
table.SetMetadata(metadata{
HookEntry: replace.HookEntry,
Underflow: replace.Underflow,
@@ -450,16 +541,16 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
Size: replace.Size,
})
ipt.Tables[replace.Name.String()] = table
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
return nil
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
// only the matchers.
-func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) {
+func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
nflog("set entries: parsing matchers of size %d", len(optVal))
- var matchers []iptables.Matcher
+ var matchers []stack.Matcher
for len(optVal) > 0 {
nflog("set entries: optVal has len %d", len(optVal))
@@ -501,7 +592,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
// parseTarget parses a target from optVal. optVal should contain only the
// target.
-func parseTarget(optVal []byte) (iptables.Target, error) {
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
@@ -519,18 +610,12 @@ func parseTarget(optVal []byte) (iptables.Target, error) {
buf = optVal[:linux.SizeOfXTStandardTarget]
binary.Unmarshal(buf, usermem.ByteOrder, &standardTarget)
- verdict, err := translateToStandardVerdict(standardTarget.Verdict)
- if err != nil {
- return nil, err
- }
- switch verdict {
- case iptables.Accept:
- return iptables.UnconditionalAcceptTarget{}, nil
- case iptables.Drop:
- return iptables.UnconditionalDropTarget{}, nil
- default:
- return nil, fmt.Errorf("Unknown verdict: %v", verdict)
+ if standardTarget.Verdict < 0 {
+ // A Verdict < 0 indicates a non-jump verdict.
+ return translateToStandardTarget(standardTarget.Verdict)
}
+ // A verdict >= 0 indicates a jump.
+ return JumpTarget{Offset: uint32(standardTarget.Verdict)}, nil
case errorTargetName:
// Error target.
@@ -548,55 +633,128 @@ func parseTarget(optVal []byte) (iptables.Target, error) {
// somehow fall through every rule.
// * To mark the start of a user defined chain. These
// rules have an error with the name of the chain.
- switch errorTarget.Name.String() {
+ switch name := errorTarget.Name.String(); name {
case errorTargetName:
- return iptables.ErrorTarget{}, nil
+ nflog("set entries: error target")
+ return stack.ErrorTarget{}, nil
default:
- return nil, fmt.Errorf("unknown error target %q doesn't exist or isn't supported yet.", errorTarget.Name.String())
+ // User defined chain.
+ nflog("set entries: user-defined target %q", name)
+ return stack.UserChainTarget{Name: name}, nil
+ }
+
+ case redirectTargetName:
+ // Redirect target.
+ if len(optVal) < linux.SizeOfXTRedirectTarget {
+ return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
}
+
+ if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = optVal[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ var target stack.RedirectTarget
+ nfRange := redirectTarget.NfRange
+
+ // RangeSize should be 1.
+ if nfRange.RangeSize != 1 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+ target.RangeProtoSpecified = true
+
+ target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // Convert port from big endian to little endian.
+ port := make([]byte, 2)
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
+ target.MinPort = binary.LittleEndian.Uint16(port)
+
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
+ target.MaxPort = binary.LittleEndian.Uint16(port)
+ return target, nil
}
// Unknown target.
return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
}
-func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) {
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
if containsUnsupportedFields(iptip) {
- return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
}
- return iptables.IPHeaderFilter{
- Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
}, nil
}
func containsUnsupportedFields(iptip linux.IPTIP) bool {
- // Currently we check that everything except protocol is zeroed.
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - The inverse destination IP check flag
var emptyInetAddr = linux.InetAddr{}
var emptyInterface = [linux.IFNAMSIZ]byte{}
- return iptip.Dst != emptyInetAddr ||
- iptip.Src != emptyInetAddr ||
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP)
+ return iptip.Src != emptyInetAddr ||
iptip.SrcMask != emptyInetAddr ||
- iptip.DstMask != emptyInetAddr ||
iptip.InputInterface != emptyInterface ||
iptip.OutputInterface != emptyInterface ||
iptip.InputInterfaceMask != emptyInterface ||
iptip.OutputInterfaceMask != emptyInterface ||
iptip.Flags != 0 ||
- iptip.InverseFlags != 0
+ iptip.InverseFlags&^inverseMask != 0
+}
+
+func validUnderflow(rule stack.Rule) bool {
+ if len(rule.Matchers) != 0 {
+ return false
+ }
+ switch rule.Target.(type) {
+ case stack.AcceptTarget, stack.DropTarget:
+ return true
+ default:
+ return false
+ }
}
-func hookFromLinux(hook int) iptables.Hook {
+func hookFromLinux(hook int) stack.Hook {
switch hook {
case linux.NF_INET_PRE_ROUTING:
- return iptables.Prerouting
+ return stack.Prerouting
case linux.NF_INET_LOCAL_IN:
- return iptables.Input
+ return stack.Input
case linux.NF_INET_FORWARD:
- return iptables.Forward
+ return stack.Forward
case linux.NF_INET_LOCAL_OUT:
- return iptables.Output
+ return stack.Output
case linux.NF_INET_POST_ROUTING:
- return iptables.Postrouting
+ return stack.Postrouting
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
new file mode 100644
index 000000000..5949a7c29
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -0,0 +1,128 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameOwner = "owner"
+
+func init() {
+ registerMatchMaker(ownerMarshaler{})
+}
+
+// ownerMarshaler implements matchMaker for owner matching.
+type ownerMarshaler struct{}
+
+// name implements matchMaker.name.
+func (ownerMarshaler) name() string {
+ return matcherNameOwner
+}
+
+// marshal implements matchMaker.marshal.
+func (ownerMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*OwnerMatcher)
+ iptOwnerInfo := linux.IPTOwnerInfo{
+ UID: matcher.uid,
+ GID: matcher.gid,
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matcher.matchUID {
+ iptOwnerInfo.Match = linux.XT_OWNER_UID
+ } else if matcher.matchGID {
+ panic("GID match is not supported.")
+ } else {
+ panic("UID match is not set.")
+ }
+
+ buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfIPTOwnerInfo {
+ return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.IPTOwnerInfo
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+
+ if matchData.Invert != 0 {
+ return nil, fmt.Errorf("invert flag is not supported for owner match")
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matchData.Match&linux.XT_OWNER_UID != linux.XT_OWNER_UID {
+ return nil, fmt.Errorf("owner match is only supported for uid")
+ }
+
+ // Check Flags.
+ var owner OwnerMatcher
+ owner.uid = matchData.UID
+ owner.gid = matchData.GID
+ owner.matchUID = true
+
+ return &owner, nil
+}
+
+type OwnerMatcher struct {
+ uid uint32
+ gid uint32
+ matchUID bool
+ matchGID bool
+ invert uint8
+}
+
+// Name implements Matcher.Name.
+func (*OwnerMatcher) Name() string {
+ return matcherNameOwner
+}
+
+// Match implements Matcher.Match.
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+ // Support only for OUTPUT chain.
+ // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
+ if hook != stack.Output {
+ return false, true
+ }
+
+ // If the packet owner is not set, drop the packet.
+ // Support for uid match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if pkt.Owner == nil || !om.matchUID {
+ return false, true
+ }
+
+ // TODO(gvisor.dev/issue/170): Need to add tests to verify
+ // drop rule when packet UID does not match owner matcher UID.
+ if pkt.Owner.UID() != om.uid {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
new file mode 100644
index 000000000..c948de876
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -0,0 +1,34 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
+
+// JumpTarget implements stack.Target.
+type JumpTarget struct {
+ // Offset is the byte offset of the rule to jump to. It is used for
+ // marshaling and unmarshaling.
+ Offset uint32
+
+ // RuleNum is the rule to jump to.
+ RuleNum int
+}
+
+// Action implements stack.Target.Action.
+func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) {
+ return stack.RuleJump, jt.RuleNum
+}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index f9945e214..ff1cfd8f6 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (tcpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (tcpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*TCPMatcher)
xttcp := linux.XTTCP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTTCP {
return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
}
@@ -97,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
@@ -115,7 +114,7 @@ func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var tcpHeader header.TCP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 86aa11696..3359418c1 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (udpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (udpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*UDPMatcher)
xtudp := linux.XTUDP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTUDP {
return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
}
@@ -94,11 +93,11 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the iptables.Check codepath as matchers are added.
+ // into the stack.Check codepath as matchers are added.
if netHeader.TransportProtocol() != header.UDPProtocolNumber {
return false, false
}
@@ -114,7 +113,7 @@ func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var udpHeader header.UDP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netlink/message.go b/pkg/sentry/socket/netlink/message.go
index 4ea252ccb..0899c61d1 100644
--- a/pkg/sentry/socket/netlink/message.go
+++ b/pkg/sentry/socket/netlink/message.go
@@ -23,18 +23,11 @@ import (
"gvisor.dev/gvisor/pkg/usermem"
)
-// alignUp rounds a length up to an alignment.
-//
-// Preconditions: align is a power of two.
-func alignUp(length int, align uint) int {
- return (length + int(align) - 1) &^ (int(align) - 1)
-}
-
// alignPad returns the length of padding required for alignment.
//
// Preconditions: align is a power of two.
func alignPad(length int, align uint) int {
- return alignUp(length, align) - length
+ return binary.AlignUp(length, align) - length
}
// Message contains a complete serialized netlink message.
@@ -138,7 +131,7 @@ func (m *Message) Finalize() []byte {
// Align the message. Note that the message length in the header (set
// above) is the useful length of the message, not the total aligned
// length. See net/netlink/af_netlink.c:__nlmsg_put.
- aligned := alignUp(len(m.buf), linux.NLMSG_ALIGNTO)
+ aligned := binary.AlignUp(len(m.buf), linux.NLMSG_ALIGNTO)
m.putZeros(aligned - len(m.buf))
return m.buf
}
@@ -173,7 +166,7 @@ func (m *Message) PutAttr(atype uint16, v interface{}) {
m.Put(v)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
@@ -190,7 +183,7 @@ func (m *Message) PutAttrString(atype uint16, s string) {
m.putZeros(1)
// Align the attribute.
- aligned := alignUp(l, linux.NLA_ALIGNTO)
+ aligned := binary.AlignUp(l, linux.NLA_ALIGNTO)
m.putZeros(aligned - l)
}
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ab01cb4fa..cbf46b1e9 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -38,7 +38,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index ed2fbcceb..d5879c10f 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -29,6 +29,7 @@ import (
"io"
"math"
"reflect"
+ "sync/atomic"
"syscall"
"time"
@@ -62,7 +63,13 @@ import (
func mustCreateMetric(name, description string) *tcpip.StatCounter {
var cm tcpip.StatCounter
- metric.MustRegisterCustomUint64Metric(name, false /* sync */, description, cm.Value)
+ metric.MustRegisterCustomUint64Metric(name, true /* cumulative */, false /* sync */, description, cm.Value)
+ return &cm
+}
+
+func mustCreateGauge(name, description string) *tcpip.StatCounter {
+ var cm tcpip.StatCounter
+ metric.MustRegisterCustomUint64Metric(name, false /* cumulative */, false /* sync */, description, cm.Value)
return &cm
}
@@ -150,10 +157,10 @@ var Metrics = tcpip.Stats{
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
PassiveConnectionOpenings: mustCreateMetric("/netstack/tcp/passive_connection_openings", "Number of connections opened successfully via Listen."),
- CurrentEstablished: mustCreateMetric("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."),
- CurrentConnected: mustCreateMetric("/netstack/tcp/current_open", "Number of connections that are in connected state."),
+ CurrentEstablished: mustCreateGauge("/netstack/tcp/current_established", "Number of connections in ESTABLISHED state now."),
+ CurrentConnected: mustCreateGauge("/netstack/tcp/current_open", "Number of connections that are in connected state."),
EstablishedResets: mustCreateMetric("/netstack/tcp/established_resets", "Number of times TCP connections have made a direct transition to the CLOSED state from either the ESTABLISHED state or the CLOSE-WAIT state"),
- EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "number of times established TCP connections made a transition to CLOSED state."),
+ EstablishedClosed: mustCreateMetric("/netstack/tcp/established_closed", "Number of times established TCP connections made a transition to CLOSED state."),
EstablishedTimedout: mustCreateMetric("/netstack/tcp/established_timedout", "Number of times an established connection was reset because of keep-alive time out."),
ListenOverflowSynDrop: mustCreateMetric("/netstack/tcp/listen_overflow_syn_drop", "Number of times the listen queue overflowed and a SYN was dropped."),
ListenOverflowAckDrop: mustCreateMetric("/netstack/tcp/listen_overflow_ack_drop", "Number of times the listen queue overflowed and the final ACK in the handshake was dropped."),
@@ -264,6 +271,12 @@ type SocketOperations struct {
skType linux.SockType
protocol int
+ // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
+ // Must be accessed using atomic operations. It must only be written
+ // with readMu held but can be read without holding readMu. The latter
+ // is required to avoid deadlocks in epoll Readiness checks.
+ readViewHasData uint32
+
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
// readView contains the remaining payload from the last packet.
@@ -293,7 +306,7 @@ type SocketOperations struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}
@@ -410,21 +423,24 @@ func (s *SocketOperations) isPacketBased() bool {
// fetchReadView updates the readView field of the socket if it's currently
// empty. It assumes that the socket is locked.
+//
+// Precondition: s.readMu must be held.
func (s *SocketOperations) fetchReadView() *syserr.Error {
if len(s.readView) > 0 {
return nil
}
-
s.readView = nil
s.sender = tcpip.FullAddress{}
v, cms, err := s.Endpoint.Read(&s.sender)
if err != nil {
+ atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
}
s.readView = v
s.readCM = cms
+ atomic.StoreUint32(&s.readViewHasData, 1)
return nil
}
@@ -525,7 +541,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -598,7 +614,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -623,11 +639,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check our cached value iff the caller asked for readability and the
// endpoint itself is currently not readable.
if (mask & ^r & waiter.EventIn) != 0 {
- s.readMu.Lock()
- if len(s.readView) > 0 {
+ if atomic.LoadUint32(&s.readViewHasData) == 1 {
r |= waiter.EventIn
}
- s.readMu.Unlock()
}
return r
@@ -655,7 +669,7 @@ func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error
// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
// represented by the empty string.
//
-// TODO(gvisor.dev/issues/1556): remove this function.
+// TODO(gvisor.dev/issue/1556): remove this function.
func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
@@ -712,14 +726,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
- if err != nil {
- return err
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
}
- if err := s.checkFamily(family, true /* exact */); err != nil {
- return err
+
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only family, protocol and ifindex.
+ // In function AddressAndFamily, we check the address length which is
+ // not needed for AF_PACKET bind.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
}
- addr = s.mapFamily(addr, family)
// Issue the bind request to the endpoint.
return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
@@ -902,7 +946,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -927,8 +971,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
return nil, syserr.ErrProtocolNotAvailable
}
+func boolToInt32(v bool) int32 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -960,12 +1011,11 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.PasscredOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.PasscredOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1004,24 +1054,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReuseAddressOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReusePortOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1051,24 +1099,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.BroadcastOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.KeepaliveEnabledOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
@@ -1118,47 +1164,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ v, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v == 0 {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(!v), nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.CorkOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.CorkOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.QuickAckOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MaxSegOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MaxSegOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1290,11 +1330,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1304,8 +1340,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if outLen == 0 {
return make([]byte, 0), nil
}
- var v tcpip.IPv6TrafficClassOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1318,6 +1354,17 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
}
return ib, nil
+ case linux.IPV6_RECVTCLASS:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveTClassOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ return boolToInt32(v), nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1332,8 +1379,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.TTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.TTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1349,8 +1396,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastTTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1375,23 +1422,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastLoopOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(v), nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
return []byte(nil), nil
}
- var v tcpip.IPv4TOSOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
@@ -1408,11 +1451,18 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
+ return boolToInt32(v), nil
+
+ case linux.IP_PKTINFO:
+ if outLen < sizeOfInt32 {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ v, err := ep.GetSockOptBool(tcpip.ReceiveIPPacketInfoOption)
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
}
- return o, nil
+ return boolToInt32(v), nil
default:
emitUnimplementedEventIP(t, name)
@@ -1472,7 +1522,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
// sockets backed by a commonEndpoint.
-func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
switch level {
case linux.SOL_SOCKET:
return setSockOptSocket(t, s, ep, name, optVal)
@@ -1499,7 +1549,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n
}
// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
-func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.SO_SNDBUF:
if len(optVal) < sizeOfInt32 {
@@ -1523,7 +1573,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1531,7 +1581,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
@@ -1559,7 +1609,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
@@ -1567,7 +1617,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1575,7 +1625,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1647,11 +1697,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- var o int
- if v == 0 {
- o = 1
- }
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -1659,7 +1705,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -1667,7 +1713,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -1675,7 +1721,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
case linux.TCP_KEEPIDLE:
if len(optVal) < sizeOfInt32 {
@@ -1762,6 +1808,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
linux.IPV6_IPSEC_POLICY,
linux.IPV6_JOIN_ANYCAST,
linux.IPV6_LEAVE_ANYCAST,
+ // TODO(b/148887420): Add support for IPV6_PKTINFO.
linux.IPV6_PKTINFO,
linux.IPV6_ROUTER_ALERT,
linux.IPV6_XFRM_POLICY,
@@ -1785,7 +1832,15 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
if v == -1 {
v = 0
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v)))
+
+ case linux.IPV6_RECVTCLASS:
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
default:
emitUnimplementedEventIPv6(t, name)
@@ -1862,7 +1917,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if v < 0 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v)))
case linux.IP_ADD_MEMBERSHIP:
req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
@@ -1909,9 +1964,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(
- tcpip.MulticastLoopOption(v != 0),
- ))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -1930,7 +1983,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
} else if v < 1 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v)))
case linux.IP_TOS:
if len(optVal) == 0 {
@@ -1940,7 +1993,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v)))
case linux.IP_RECVTOS:
v, err := parseIntOrChar(optVal)
@@ -1949,6 +2002,16 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTOSOption, v != 0))
+ case linux.IP_PKTINFO:
+ if len(optVal) == 0 {
+ return nil
+ }
+ v, err := parseIntOrChar(optVal)
+ if err != nil {
+ return err
+ }
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveIPPacketInfoOption, v != 0))
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -1964,7 +2027,6 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
linux.IP_NODEFRAG,
linux.IP_OPTIONS,
linux.IP_PASSSEC,
- linux.IP_PKTINFO,
linux.IP_RECVERR,
linux.IP_RECVFRAGSIZE,
linux.IP_RECVOPTS,
@@ -2061,7 +2123,6 @@ func emitUnimplementedEventIPv6(t *kernel.Task, name int) {
linux.IPV6_RECVPATHMTU,
linux.IPV6_RECVPKTINFO,
linux.IPV6_RECVRTHDR,
- linux.IPV6_RECVTCLASS,
linux.IPV6_RTHDR,
linux.IPV6_RTHDRDSTOPTS,
linux.IPV6_TCLASS,
@@ -2256,6 +2317,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
}
copied += n
s.readView.TrimFront(n)
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
@@ -2302,9 +2367,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// caller-supplied buffer.
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
- s.readMu.Unlock()
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
+ s.readMu.Unlock()
return n, 0, nil, 0, cmsg, err
}
@@ -2378,6 +2443,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readView.TrimFront(int(n))
}
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
var flags int
if msgLen > int(n) {
flags |= linux.MSG_TRUNC
@@ -2395,10 +2464,14 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
func (s *SocketOperations) controlMessages() socket.ControlMessages {
return socket.ControlMessages{
IP: tcpip.ControlMessages{
- HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
- Timestamp: s.readCM.Timestamp,
- HasTOS: s.readCM.HasTOS,
- TOS: s.readCM.TOS,
+ HasTimestamp: s.readCM.HasTimestamp && s.sockOptTimestamp,
+ Timestamp: s.readCM.Timestamp,
+ HasTOS: s.readCM.HasTOS,
+ TOS: s.readCM.TOS,
+ HasTClass: s.readCM.HasTClass,
+ TClass: s.readCM.TClass,
+ HasIPPacketInfo: s.readCM.HasIPPacketInfo,
+ PacketInfo: s.readCM.PacketInfo,
},
}
}
@@ -2585,7 +2658,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
// Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
v += len(s.readView)
+ s.readMu.Unlock()
if v > math.MaxInt32 {
v = math.MaxInt32
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5afff2564..c3f04b613 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,10 +62,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
- // TODO(b/142504697): "In order to create a raw socket, a
- // process must have the CAP_NET_RAW capability in the user
- // namespace that governs its network namespace." - raw(7)
-
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -75,6 +71,8 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
switch protocol {
case syscall.IPPROTO_ICMP:
return header.ICMPv4ProtocolNumber, true, nil
+ case syscall.IPPROTO_ICMPV6:
+ return header.ICMPv6ProtocolNumber, true, nil
case syscall.IPPROTO_UDP:
return header.UDPProtocolNumber, true, nil
case syscall.IPPROTO_TCP:
@@ -124,6 +122,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
} else {
ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
}
if e != nil {
return nil, syserr.TranslateNetstackError(e)
@@ -133,10 +137,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
}
func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // TODO(b/142504697): "In order to create a packet socket, a process
- // must have the CAP_NET_RAW capability in the user namespace that
- // governs its network namespace." - packet(7)
-
// Packet sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(t)
if !creds.HasCapability(linux.CAP_NET_RAW) {
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 0692482e9..f5fa18136 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -200,36 +199,66 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
switch stats := stat.(type) {
+ case *inet.StatDev:
+ for _, ni := range s.Stack.NICInfo() {
+ if ni.Name != arg {
+ continue
+ }
+ // TODO(gvisor.dev/issue/2103) Support stubbed stats.
+ *stats = inet.StatDev{
+ // Receive section.
+ ni.Stats.Rx.Bytes.Value(), // bytes.
+ ni.Stats.Rx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // frame.
+ 0, // compressed.
+ 0, // multicast.
+ // Transmit section.
+ ni.Stats.Tx.Bytes.Value(), // bytes.
+ ni.Stats.Tx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // colls.
+ 0, // carrier.
+ 0, // compressed.
+ }
+ break
+ }
case *inet.StatSNMPIP:
ip := Metrics.IP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPIP{
- 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL.
+ 0, // Ip/Forwarding.
+ 0, // Ip/DefaultTTL.
ip.PacketsReceived.Value(), // InReceives.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors.
+ 0, // Ip/InHdrErrors.
ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards.
+ 0, // Ip/ForwDatagrams.
+ 0, // Ip/InUnknownProtos.
+ 0, // Ip/InDiscards.
ip.PacketsDelivered.Value(), // InDelivers.
ip.PacketsSent.Value(), // OutRequests.
ip.OutgoingPacketErrors.Value(), // OutDiscards.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates.
+ 0, // Ip/OutNoRoutes.
+ 0, // Support Ip/ReasmTimeout.
+ 0, // Support Ip/ReasmReqds.
+ 0, // Support Ip/ReasmOKs.
+ 0, // Support Ip/ReasmFails.
+ 0, // Support Ip/FragOKs.
+ 0, // Support Ip/FragFails.
+ 0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs.
+ 0, // Icmp/InMsgs.
Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors.
+ 0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
in.ParamProblem.Value(), // InParmProbs.
@@ -241,7 +270,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.TimestampReply.Value(), // InTimestampReps.
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs.
+ 0, // Icmp/OutMsgs.
Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
out.DstUnreachable.Value(), // OutDestUnreachs.
out.TimeExceeded.Value(), // OutTimeExcds.
@@ -277,15 +306,16 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
case *inet.StatSNMPUDP:
udp := Metrics.UDP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPUDP{
udp.PacketsReceived.Value(), // InDatagrams.
udp.UnknownPortErrors.Value(), // NoPorts.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors.
+ 0, // Udp/InErrors.
udp.PacketsSent.Value(), // OutDatagrams.
udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti.
+ 0, // Udp/SndbufErrors.
+ 0, // Udp/InCsumErrors.
+ 0, // Udp/IgnoredMulti.
}
default:
return syserr.ErrEndpointOperation.ToError()
@@ -332,7 +362,7 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (iptables.IPTables, error) {
+func (s *Stack) IPTables() (stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 50d9744e6..6580bd6e9 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
@@ -48,11 +49,25 @@ func (c *ControlMessages) Release() {
c.Unix.Release()
}
-// Socket is the interface containing socket syscalls used by the syscall layer
-// to redirect them to the appropriate implementation.
+// Socket is an interface combining fs.FileOperations and SocketOps,
+// representing a VFS1 socket file.
type Socket interface {
fs.FileOperations
+ SocketOps
+}
+
+// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps,
+// representing a VFS2 socket file.
+type SocketVFS2 interface {
+ vfs.FileDescriptionImpl
+ SocketOps
+}
+// SocketOps is the interface containing socket syscalls used by the syscall
+// layer to redirect them to the appropriate implementation.
+//
+// It is implemented by both Socket and SocketVFS2.
+type SocketOps interface {
// Connect implements the connect(2) linux syscall.
Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
@@ -153,6 +168,8 @@ var families = make(map[int][]Provider)
// RegisterProvider registers the provider of a given address family so that
// sockets of that type can be created via socket() and/or socketpair()
// syscalls.
+//
+// This should only be called during the initialization of the address family.
func RegisterProvider(family int, provider Provider) {
families[family] = append(families[family], provider)
}
@@ -216,6 +233,74 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino))
}
+// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for
+// specific address families (e.g., AF_INET).
+type ProviderVFS2 interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error)
+}
+
+// familiesVFS2 holds a map of all known address families and their providers.
+var familiesVFS2 = make(map[int][]ProviderVFS2)
+
+// RegisterProviderVFS2 registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProviderVFS2(family int, provider ProviderVFS2) {
+ familiesVFS2[family] = append(familiesVFS2[family], provider)
+}
+
+// NewVFS2 creates a new socket with the given family, type and protocol.
+func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ for _, p := range familiesVFS2[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocketVFS2(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// PairVFS2 creates a new connected socket pair with the given family, type and
+// protocol.
+func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ providers, ok := familiesVFS2[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocketVFS2(s1)
+ k.RecordSocketVFS2(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
// SendReceiveTimeout stores timeouts for send and receive calls.
//
// It is meant to be embedded into Socket implementations to help satisfy the
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 08743deba..de2cc4bdf 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,23 +8,27 @@ go_library(
"device.go",
"io.go",
"unix.go",
+ "unix_vfs2.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/fspath",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 74bcd6300..c708b6030 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/ilist",
+ "//pkg/log",
"//pkg/refs",
"//pkg/sync",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2ef654235..2f1b127df 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -838,24 +839,43 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option. Currently not supported.
func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.PasscredOption:
- e.setPasscred(v != 0)
- return nil
- }
return nil
}
func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ case tcpip.PasscredOption:
+ e.setPasscred(v)
+ case tcpip.ReuseAddressOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.PasscredOption:
+ return e.Passcred(), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
@@ -914,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return int(v), nil
default:
+ log.Warningf("Unsupported socket option: %d", opt)
return -1, tcpip.ErrUnknownProtocolOption
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
- }
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
+ log.Warningf("Unsupported socket option: %T", opt)
return tcpip.ErrUnknownProtocolOption
}
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4d30aa714..7c64f30fa 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -33,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -52,11 +54,8 @@ type SocketOperations struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- refs.AtomicRefCount
- socket.SendReceiveTimeout
- ep transport.Endpoint
- stype linux.SockType
+ socketOpsCommon
}
// New creates a new unix socket.
@@ -75,16 +74,29 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
}
s := SocketOperations{
- ep: ep,
- stype: stype,
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
}
s.EnableLeakCheck("unix.SocketOperations")
return fs.NewFile(ctx, d, flags, &s)
}
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ stype linux.SockType
+}
+
// DecRef implements RefCounter.DecRef.
-func (s *SocketOperations) DecRef() {
+func (s *socketOpsCommon) DecRef() {
s.DecRefWithDestructor(func() {
s.ep.Close()
})
@@ -97,7 +109,7 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
-func (s *SocketOperations) isPacket() bool {
+func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
return true
@@ -110,7 +122,7 @@ func (s *SocketOperations) isPacket() bool {
}
// Endpoint extracts the transport.Endpoint.
-func (s *SocketOperations) Endpoint() transport.Endpoint {
+func (s *socketOpsCommon) Endpoint() transport.Endpoint {
return s.ep
}
@@ -143,7 +155,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -155,7 +167,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -178,7 +190,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// Listen implements the linux syscall listen(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return s.ep.Listen(backlog)
}
@@ -310,6 +322,8 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Create the socket.
+ //
+ // TODO(gvisor.dev/issue/2324): Correctly set file permissions.
childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
if err != nil {
return syserr.ErrPortInUse
@@ -345,6 +359,31 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
return ep, nil
}
+ if kernel.VFS2Enabled {
+ p := fspath.Parse(path)
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ relPath := !p.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: p,
+ FollowFinalSymlink: true,
+ }
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop)
+ root.DecRef()
+ if relPath {
+ start.DecRef()
+ }
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+ return ep, nil
+ }
+
// Find the node in the filesystem.
root := t.FSContext().RootDirectory()
cwd := t.FSContext().WorkingDirectory()
@@ -363,12 +402,11 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
// No socket!
return nil, syserr.ErrConnectionRefused
}
-
return ep, nil
}
// Connect implements the linux syscall connect(2) for unix sockets.
-func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
ep, err := extractEndpoint(t, sockaddr)
if err != nil {
return err
@@ -379,7 +417,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
return s.ep.Connect(t, ep)
}
-// Writev implements fs.FileOperations.Write.
+// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
t := kernel.TaskFromContext(ctx)
ctrl := control.New(t, s.ep, nil)
@@ -399,7 +437,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
Ctx: t,
Endpoint: s.ep,
@@ -453,27 +491,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
// Passcred implements transport.Credentialer.Passcred.
-func (s *SocketOperations) Passcred() bool {
+func (s *socketOpsCommon) Passcred() bool {
return s.ep.Passcred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
-func (s *SocketOperations) ConnectedPasscred() bool {
+func (s *socketOpsCommon) ConnectedPasscred() bool {
return s.ep.ConnectedPasscred()
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.ep.Readiness(mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.ep.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
@@ -485,7 +523,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
f, err := netstack.ConvertShutdown(how)
if err != nil {
return err
@@ -511,7 +549,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -648,12 +686,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// State implements socket.Socket.State.
-func (s *SocketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
return s.ep.State()
}
// Type implements socket.Socket.Type.
-func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
// Unix domain sockets always have a protocol of 0.
return linux.AF_UNIX, s.stype, 0
}
@@ -706,4 +744,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
func init() {
socket.RegisterProvider(linux.AF_UNIX, &provider{})
+ socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{})
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
new file mode 100644
index 000000000..3e54d49c4
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -0,0 +1,348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 implements socket.SocketVFS2 (and by extension,
+// vfs.FileDescriptionImpl) for Unix sockets.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ socketOpsCommon
+}
+
+// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket.
+func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ sock := NewFDImpl(ep, stype)
+ vfsfd := &sock.vfsfd
+ if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// NewFDImpl creates and returns a new SocketVFS2.
+func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ return &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
+ defer s.socketOpsCommon.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ // We expect this to be a FileDescription here.
+ ns, err := NewVFS2File(t, ep, s.stype)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocketVFS2(ns)
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ path := fspath.Parse(p)
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ relPath := !path.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ }
+ err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // TODO(gvisor.dev/issue/2324): The file permissions should be taken
+ // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind),
+ // but VFS1 just always uses 0400. Resolve this inconsistency.
+ Mode: linux.S_IFSOCK | 0400,
+ Endpoint: bep,
+ })
+ if err == syserror.EEXIST {
+ return syserr.ErrAddressInUse
+ }
+ return syserr.FromError(err)
+ }
+
+ return nil
+ })
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, uio, args)
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// Release implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// providerVFS2 is a unix domain socket provider for VFS2.
+type providerVFS2 struct{}
+
+func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ f, err := NewVFS2File(t, ep, stype)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+ return f, nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1, err := NewVFS2File(t, ep1, stype)
+ if err != nil {
+ ep1.Close()
+ ep2.Close()
+ return nil, nil, err
+ }
+ s2, err := NewVFS2File(t, ep2, stype)
+ if err != nil {
+ s1.DecRef()
+ ep2.Close()
+ return nil, nil, err
+ }
+
+ return s1, s2, nil
+}