summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/BUILD2
-rw-r--r--pkg/sentry/socket/netfilter/extensions.go14
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go208
-rw-r--r--pkg/sentry/socket/netfilter/owner_matcher.go128
-rw-r--r--pkg/sentry/socket/netfilter/targets.go11
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go11
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go13
-rw-r--r--pkg/sentry/socket/netstack/BUILD1
-rw-r--r--pkg/sentry/socket/netstack/netstack.go239
-rw-r--r--pkg/sentry/socket/netstack/provider.go14
-rw-r--r--pkg/sentry/socket/netstack/stack.go76
-rw-r--r--pkg/sentry/socket/socket.go89
-rw-r--r--pkg/sentry/socket/unix/BUILD4
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD1
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go48
-rw-r--r--pkg/sentry/socket/unix/unix.go89
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go348
18 files changed, 1017 insertions, 280 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index 611fa22c3..c40c6d673 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -16,6 +16,7 @@ go_library(
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 7cd2ce55b..721094bbf 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"extensions.go",
"netfilter.go",
+ "owner_matcher.go",
"targets.go",
"tcp_matcher.go",
"udp_matcher.go",
@@ -22,7 +23,6 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/stack",
"//pkg/usermem",
],
diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go
index b4b244abf..0336a32d8 100644
--- a/pkg/sentry/socket/netfilter/extensions.go
+++ b/pkg/sentry/socket/netfilter/extensions.go
@@ -19,7 +19,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,12 +37,12 @@ type matchMaker interface {
// name is the matcher name as stored in the xt_entry_match struct.
name() string
- // marshal converts from an iptables.Matcher to an ABI struct.
- marshal(matcher iptables.Matcher) []byte
+ // marshal converts from an stack.Matcher to an ABI struct.
+ marshal(matcher stack.Matcher) []byte
// unmarshal converts from the ABI matcher struct to an
- // iptables.Matcher.
- unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error)
+ // stack.Matcher.
+ unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error)
}
// matchMakers maps the name of supported matchers to the matchMaker that
@@ -58,7 +58,7 @@ func registerMatchMaker(mm matchMaker) {
matchMakers[mm.name()] = mm
}
-func marshalMatcher(matcher iptables.Matcher) []byte {
+func marshalMatcher(matcher stack.Matcher) []byte {
matchMaker, ok := matchMakers[matcher.Name()]
if !ok {
panic(fmt.Sprintf("Unknown matcher of type %T.", matcher))
@@ -86,7 +86,7 @@ func marshalEntryMatch(name string, data []byte) []byte {
return append(buf, make([]byte, size-len(buf))...)
}
-func unmarshalMatcher(match linux.XTEntryMatch, filter iptables.IPHeaderFilter, buf []byte) (iptables.Matcher, error) {
+func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf []byte) (stack.Matcher, error) {
matchMaker, ok := matchMakers[match.Name.String()]
if !ok {
return nil, fmt.Errorf("unsupported matcher with name %q", match.Name.String())
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index 2ec11f6ac..878f81fd5 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -26,7 +26,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -35,6 +35,11 @@ import (
// shouldn't be reached - an error has occurred if we fall through to one.
const errorTargetName = "ERROR"
+// redirectTargetName is used to mark targets as redirect targets. Redirect
+// targets should be reached for only NAT and Mangle tables. These targets will
+// change the destination port/destination IP for packets.
+const redirectTargetName = "REDIRECT"
+
// Metadata is used to verify that we are correctly serializing and
// deserializing iptables into structs consumable by the iptables tool. We save
// a metadata struct when the tables are written, and when they are read out we
@@ -123,19 +128,19 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
return entries, nil
}
-func findTable(stack *stack.Stack, tablename linux.TableName) (iptables.Table, error) {
- ipt := stack.IPTables()
+func findTable(stk *stack.Stack, tablename linux.TableName) (stack.Table, error) {
+ ipt := stk.IPTables()
table, ok := ipt.Tables[tablename.String()]
if !ok {
- return iptables.Table{}, fmt.Errorf("couldn't find table %q", tablename)
+ return stack.Table{}, fmt.Errorf("couldn't find table %q", tablename)
}
return table, nil
}
// FillDefaultIPTables sets stack's IPTables to the default tables and
// populates them with metadata.
-func FillDefaultIPTables(stack *stack.Stack) {
- ipt := iptables.DefaultTables()
+func FillDefaultIPTables(stk *stack.Stack) {
+ ipt := stack.DefaultTables()
// In order to fill in the metadata, we have to translate ipt from its
// netstack format to Linux's giant-binary-blob format.
@@ -148,14 +153,14 @@ func FillDefaultIPTables(stack *stack.Stack) {
ipt.Tables[name] = table
}
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
}
// convertNetstackToBinary converts the iptables as stored in netstack to the
// format expected by the iptables tool. Linux stores each table as a binary
// blob that can only be traversed by parsing a bit, reading some offsets,
// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(tablename string, table iptables.Table) (linux.KernelIPTGetEntries, metadata, error) {
+func convertNetstackToBinary(tablename string, table stack.Table) (linux.KernelIPTGetEntries, metadata, error) {
// Return values.
var entries linux.KernelIPTGetEntries
var meta metadata
@@ -228,18 +233,20 @@ func convertNetstackToBinary(tablename string, table iptables.Table) (linux.Kern
return entries, meta, nil
}
-func marshalTarget(target iptables.Target) []byte {
+func marshalTarget(target stack.Target) []byte {
switch tg := target.(type) {
- case iptables.AcceptTarget:
- return marshalStandardTarget(iptables.RuleAccept)
- case iptables.DropTarget:
- return marshalStandardTarget(iptables.RuleDrop)
- case iptables.ErrorTarget:
+ case stack.AcceptTarget:
+ return marshalStandardTarget(stack.RuleAccept)
+ case stack.DropTarget:
+ return marshalStandardTarget(stack.RuleDrop)
+ case stack.ErrorTarget:
return marshalErrorTarget(errorTargetName)
- case iptables.UserChainTarget:
+ case stack.UserChainTarget:
return marshalErrorTarget(tg.Name)
- case iptables.ReturnTarget:
- return marshalStandardTarget(iptables.RuleReturn)
+ case stack.ReturnTarget:
+ return marshalStandardTarget(stack.RuleReturn)
+ case stack.RedirectTarget:
+ return marshalRedirectTarget()
case JumpTarget:
return marshalJumpTarget(tg)
default:
@@ -247,7 +254,7 @@ func marshalTarget(target iptables.Target) []byte {
}
}
-func marshalStandardTarget(verdict iptables.RuleVerdict) []byte {
+func marshalStandardTarget(verdict stack.RuleVerdict) []byte {
nflog("convert to binary: marshalling standard target")
// The target's name will be the empty string.
@@ -276,6 +283,19 @@ func marshalErrorTarget(errorName string) []byte {
return binary.Marshal(ret, usermem.ByteOrder, target)
}
+func marshalRedirectTarget() []byte {
+ // This is a redirect target named redirect
+ target := linux.XTRedirectTarget{
+ Target: linux.XTEntryTarget{
+ TargetSize: linux.SizeOfXTRedirectTarget,
+ },
+ }
+ copy(target.Target.Name[:], redirectTargetName)
+
+ ret := make([]byte, 0, linux.SizeOfXTRedirectTarget)
+ return binary.Marshal(ret, usermem.ByteOrder, target)
+}
+
func marshalJumpTarget(jt JumpTarget) []byte {
nflog("convert to binary: marshalling jump target")
@@ -295,13 +315,13 @@ func marshalJumpTarget(jt JumpTarget) []byte {
// translateFromStandardVerdict translates verdicts the same way as the iptables
// tool.
-func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
+func translateFromStandardVerdict(verdict stack.RuleVerdict) int32 {
switch verdict {
- case iptables.RuleAccept:
+ case stack.RuleAccept:
return -linux.NF_ACCEPT - 1
- case iptables.RuleDrop:
+ case stack.RuleDrop:
return -linux.NF_DROP - 1
- case iptables.RuleReturn:
+ case stack.RuleReturn:
return linux.NF_RETURN
default:
// TODO(gvisor.dev/issue/170): Support Jump.
@@ -310,18 +330,18 @@ func translateFromStandardVerdict(verdict iptables.RuleVerdict) int32 {
}
// translateToStandardTarget translates from the value in a
-// linux.XTStandardTarget to an iptables.Verdict.
-func translateToStandardTarget(val int32) (iptables.Target, error) {
+// linux.XTStandardTarget to an stack.Verdict.
+func translateToStandardTarget(val int32) (stack.Target, error) {
// TODO(gvisor.dev/issue/170): Support other verdicts.
switch val {
case -linux.NF_ACCEPT - 1:
- return iptables.AcceptTarget{}, nil
+ return stack.AcceptTarget{}, nil
case -linux.NF_DROP - 1:
- return iptables.DropTarget{}, nil
+ return stack.DropTarget{}, nil
case -linux.NF_QUEUE - 1:
return nil, errors.New("unsupported iptables verdict QUEUE")
case linux.NF_RETURN:
- return iptables.ReturnTarget{}, nil
+ return stack.ReturnTarget{}, nil
default:
return nil, fmt.Errorf("unknown iptables verdict %d", val)
}
@@ -329,7 +349,7 @@ func translateToStandardTarget(val int32) (iptables.Target, error) {
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
+func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// Get the basic rules data (struct ipt_replace).
if len(optVal) < linux.SizeOfIPTReplace {
nflog("optVal has insufficient size for replace %d", len(optVal))
@@ -341,10 +361,12 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
binary.Unmarshal(replaceBuf, usermem.ByteOrder, &replace)
// TODO(gvisor.dev/issue/170): Support other tables.
- var table iptables.Table
+ var table stack.Table
switch replace.Name.String() {
- case iptables.TablenameFilter:
- table = iptables.EmptyFilterTable()
+ case stack.TablenameFilter:
+ table = stack.EmptyFilterTable()
+ case stack.TablenameNat:
+ table = stack.EmptyNatTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
@@ -404,14 +426,14 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
return syserr.ErrInvalidArgument
}
- target, err := parseTarget(optVal[:targetSize])
+ target, err := parseTarget(filter, optVal[:targetSize])
if err != nil {
nflog("failed to parse target: %v", err)
return syserr.ErrInvalidArgument
}
optVal = optVal[targetSize:]
- table.Rules = append(table.Rules, iptables.Rule{
+ table.Rules = append(table.Rules, stack.Rule{
Filter: filter,
Target: target,
Matchers: matchers,
@@ -442,11 +464,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
table.Underflows[hk] = ruleIdx
}
}
- if ruleIdx := table.BuiltinChains[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.BuiltinChains[hk]; ruleIdx == stack.HookUnset {
nflog("hook %v is unset.", hk)
return syserr.ErrInvalidArgument
}
- if ruleIdx := table.Underflows[hk]; ruleIdx == iptables.HookUnset {
+ if ruleIdx := table.Underflows[hk]; ruleIdx == stack.HookUnset {
nflog("underflow %v is unset.", hk)
return syserr.ErrInvalidArgument
}
@@ -455,7 +477,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// Add the user chains.
for ruleIdx, rule := range table.Rules {
- target, ok := rule.Target.(iptables.UserChainTarget)
+ target, ok := rule.Target.(stack.UserChainTarget)
if !ok {
continue
}
@@ -495,11 +517,11 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
}
// TODO(gvisor.dev/issue/170): Support other chains.
- // Since we only support modifying the INPUT chain right now, make sure
- // all other chains point to ACCEPT rules.
+ // Since we only support modifying the INPUT, PREROUTING and OUTPUT chain right now,
+ // make sure all other chains point to ACCEPT rules.
for hook, ruleIdx := range table.BuiltinChains {
- if hook != iptables.Input {
- if _, ok := table.Rules[ruleIdx].Target.(iptables.AcceptTarget); !ok {
+ if hook == stack.Forward || hook == stack.Postrouting {
+ if _, ok := table.Rules[ruleIdx].Target.(stack.AcceptTarget); !ok {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
@@ -511,7 +533,7 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- ipt := stack.IPTables()
+ ipt := stk.IPTables()
table.SetMetadata(metadata{
HookEntry: replace.HookEntry,
Underflow: replace.Underflow,
@@ -519,16 +541,16 @@ func SetEntries(stack *stack.Stack, optVal []byte) *syserr.Error {
Size: replace.Size,
})
ipt.Tables[replace.Name.String()] = table
- stack.SetIPTables(ipt)
+ stk.SetIPTables(ipt)
return nil
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
// only the matchers.
-func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Matcher, error) {
+func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, error) {
nflog("set entries: parsing matchers of size %d", len(optVal))
- var matchers []iptables.Matcher
+ var matchers []stack.Matcher
for len(optVal) > 0 {
nflog("set entries: optVal has len %d", len(optVal))
@@ -570,7 +592,7 @@ func parseMatchers(filter iptables.IPHeaderFilter, optVal []byte) ([]iptables.Ma
// parseTarget parses a target from optVal. optVal should contain only the
// target.
-func parseTarget(optVal []byte) (iptables.Target, error) {
+func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, error) {
nflog("set entries: parsing target of size %d", len(optVal))
if len(optVal) < linux.SizeOfXTEntryTarget {
return nil, fmt.Errorf("optVal has insufficient size for entry target %d", len(optVal))
@@ -614,67 +636,125 @@ func parseTarget(optVal []byte) (iptables.Target, error) {
switch name := errorTarget.Name.String(); name {
case errorTargetName:
nflog("set entries: error target")
- return iptables.ErrorTarget{}, nil
+ return stack.ErrorTarget{}, nil
default:
// User defined chain.
nflog("set entries: user-defined target %q", name)
- return iptables.UserChainTarget{Name: name}, nil
+ return stack.UserChainTarget{Name: name}, nil
+ }
+
+ case redirectTargetName:
+ // Redirect target.
+ if len(optVal) < linux.SizeOfXTRedirectTarget {
+ return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
+ }
+
+ if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
}
+
+ var redirectTarget linux.XTRedirectTarget
+ buf = optVal[:linux.SizeOfXTRedirectTarget]
+ binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget)
+
+ // Copy linux.XTRedirectTarget to stack.RedirectTarget.
+ var target stack.RedirectTarget
+ nfRange := redirectTarget.NfRange
+
+ // RangeSize should be 1.
+ if nfRange.RangeSize != 1 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // TODO(gvisor.dev/issue/170): Check if the flags are valid.
+ // Also check if we need to map ports or IP.
+ // For now, redirect target only supports destination port change.
+ // Port range and IP range are not supported yet.
+ if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+ target.RangeProtoSpecified = true
+
+ target.MinIP = tcpip.Address(nfRange.RangeIPV4.MinIP[:])
+ target.MaxIP = tcpip.Address(nfRange.RangeIPV4.MaxIP[:])
+
+ // TODO(gvisor.dev/issue/170): Port range is not supported yet.
+ if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ }
+
+ // Convert port from big endian to little endian.
+ port := make([]byte, 2)
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MinPort)
+ target.MinPort = binary.LittleEndian.Uint16(port)
+
+ binary.BigEndian.PutUint16(port, nfRange.RangeIPV4.MaxPort)
+ target.MaxPort = binary.LittleEndian.Uint16(port)
+ return target, nil
}
// Unknown target.
return nil, fmt.Errorf("unknown target %q doesn't exist or isn't supported yet.", target.Name.String())
}
-func filterFromIPTIP(iptip linux.IPTIP) (iptables.IPHeaderFilter, error) {
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
if containsUnsupportedFields(iptip) {
- return iptables.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
}
- return iptables.IPHeaderFilter{
- Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
}, nil
}
func containsUnsupportedFields(iptip linux.IPTIP) bool {
- // Currently we check that everything except protocol is zeroed.
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - The inverse destination IP check flag
var emptyInetAddr = linux.InetAddr{}
var emptyInterface = [linux.IFNAMSIZ]byte{}
- return iptip.Dst != emptyInetAddr ||
- iptip.Src != emptyInetAddr ||
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP)
+ return iptip.Src != emptyInetAddr ||
iptip.SrcMask != emptyInetAddr ||
- iptip.DstMask != emptyInetAddr ||
iptip.InputInterface != emptyInterface ||
iptip.OutputInterface != emptyInterface ||
iptip.InputInterfaceMask != emptyInterface ||
iptip.OutputInterfaceMask != emptyInterface ||
iptip.Flags != 0 ||
- iptip.InverseFlags != 0
+ iptip.InverseFlags&^inverseMask != 0
}
-func validUnderflow(rule iptables.Rule) bool {
+func validUnderflow(rule stack.Rule) bool {
if len(rule.Matchers) != 0 {
return false
}
switch rule.Target.(type) {
- case iptables.AcceptTarget, iptables.DropTarget:
+ case stack.AcceptTarget, stack.DropTarget:
return true
default:
return false
}
}
-func hookFromLinux(hook int) iptables.Hook {
+func hookFromLinux(hook int) stack.Hook {
switch hook {
case linux.NF_INET_PRE_ROUTING:
- return iptables.Prerouting
+ return stack.Prerouting
case linux.NF_INET_LOCAL_IN:
- return iptables.Input
+ return stack.Input
case linux.NF_INET_FORWARD:
- return iptables.Forward
+ return stack.Forward
case linux.NF_INET_LOCAL_OUT:
- return iptables.Output
+ return stack.Output
case linux.NF_INET_POST_ROUTING:
- return iptables.Postrouting
+ return stack.Postrouting
}
panic(fmt.Sprintf("Unknown hook %d does not correspond to a builtin chain", hook))
}
diff --git a/pkg/sentry/socket/netfilter/owner_matcher.go b/pkg/sentry/socket/netfilter/owner_matcher.go
new file mode 100644
index 000000000..5949a7c29
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/owner_matcher.go
@@ -0,0 +1,128 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+const matcherNameOwner = "owner"
+
+func init() {
+ registerMatchMaker(ownerMarshaler{})
+}
+
+// ownerMarshaler implements matchMaker for owner matching.
+type ownerMarshaler struct{}
+
+// name implements matchMaker.name.
+func (ownerMarshaler) name() string {
+ return matcherNameOwner
+}
+
+// marshal implements matchMaker.marshal.
+func (ownerMarshaler) marshal(mr stack.Matcher) []byte {
+ matcher := mr.(*OwnerMatcher)
+ iptOwnerInfo := linux.IPTOwnerInfo{
+ UID: matcher.uid,
+ GID: matcher.gid,
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matcher.matchUID {
+ iptOwnerInfo.Match = linux.XT_OWNER_UID
+ } else if matcher.matchGID {
+ panic("GID match is not supported.")
+ } else {
+ panic("UID match is not set.")
+ }
+
+ buf := make([]byte, 0, linux.SizeOfIPTOwnerInfo)
+ return marshalEntryMatch(matcherNameOwner, binary.Marshal(buf, usermem.ByteOrder, iptOwnerInfo))
+}
+
+// unmarshal implements matchMaker.unmarshal.
+func (ownerMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
+ if len(buf) < linux.SizeOfIPTOwnerInfo {
+ return nil, fmt.Errorf("buf has insufficient size for owner match: %d", len(buf))
+ }
+
+ // For alignment reasons, the match's total size may
+ // exceed what's strictly necessary to hold matchData.
+ var matchData linux.IPTOwnerInfo
+ binary.Unmarshal(buf[:linux.SizeOfIPTOwnerInfo], usermem.ByteOrder, &matchData)
+ nflog("parseMatchers: parsed IPTOwnerInfo: %+v", matchData)
+
+ if matchData.Invert != 0 {
+ return nil, fmt.Errorf("invert flag is not supported for owner match")
+ }
+
+ // Support for UID match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if matchData.Match&linux.XT_OWNER_UID != linux.XT_OWNER_UID {
+ return nil, fmt.Errorf("owner match is only supported for uid")
+ }
+
+ // Check Flags.
+ var owner OwnerMatcher
+ owner.uid = matchData.UID
+ owner.gid = matchData.GID
+ owner.matchUID = true
+
+ return &owner, nil
+}
+
+type OwnerMatcher struct {
+ uid uint32
+ gid uint32
+ matchUID bool
+ matchGID bool
+ invert uint8
+}
+
+// Name implements Matcher.Name.
+func (*OwnerMatcher) Name() string {
+ return matcherNameOwner
+}
+
+// Match implements Matcher.Match.
+func (om *OwnerMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
+ // Support only for OUTPUT chain.
+ // TODO(gvisor.dev/issue/170): Need to support for POSTROUTING chain also.
+ if hook != stack.Output {
+ return false, true
+ }
+
+ // If the packet owner is not set, drop the packet.
+ // Support for uid match.
+ // TODO(gvisor.dev/issue/170): Need to support gid match.
+ if pkt.Owner == nil || !om.matchUID {
+ return false, true
+ }
+
+ // TODO(gvisor.dev/issue/170): Need to add tests to verify
+ // drop rule when packet UID does not match owner matcher UID.
+ if pkt.Owner.UID() != om.uid {
+ return false, false
+ }
+
+ return true, false
+}
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index c421b87cf..c948de876 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -15,11 +15,10 @@
package netfilter
import (
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
)
-// JumpTarget implements iptables.Target.
+// JumpTarget implements stack.Target.
type JumpTarget struct {
// Offset is the byte offset of the rule to jump to. It is used for
// marshaling and unmarshaling.
@@ -29,7 +28,7 @@ type JumpTarget struct {
RuleNum int
}
-// Action implements iptables.Target.Action.
-func (jt JumpTarget) Action(tcpip.PacketBuffer) (iptables.RuleVerdict, int) {
- return iptables.RuleJump, jt.RuleNum
+// Action implements stack.Target.Action.
+func (jt JumpTarget) Action(stack.PacketBuffer) (stack.RuleVerdict, int) {
+ return stack.RuleJump, jt.RuleNum
}
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index f9945e214..ff1cfd8f6 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (tcpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (tcpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*TCPMatcher)
xttcp := linux.XTTCP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (tcpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (tcpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTTCP {
return nil, fmt.Errorf("buf has insufficient size for TCP match: %d", len(buf))
}
@@ -97,7 +96,7 @@ func (*TCPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (tm *TCPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
if netHeader.TransportProtocol() != header.TCPProtocolNumber {
@@ -115,7 +114,7 @@ func (tm *TCPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var tcpHeader header.TCP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 86aa11696..3359418c1 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -19,9 +19,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
- "gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -40,7 +39,7 @@ func (udpMarshaler) name() string {
}
// marshal implements matchMaker.marshal.
-func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
+func (udpMarshaler) marshal(mr stack.Matcher) []byte {
matcher := mr.(*UDPMatcher)
xtudp := linux.XTUDP{
SourcePortStart: matcher.sourcePortStart,
@@ -53,7 +52,7 @@ func (udpMarshaler) marshal(mr iptables.Matcher) []byte {
}
// unmarshal implements matchMaker.unmarshal.
-func (udpMarshaler) unmarshal(buf []byte, filter iptables.IPHeaderFilter) (iptables.Matcher, error) {
+func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Matcher, error) {
if len(buf) < linux.SizeOfXTUDP {
return nil, fmt.Errorf("buf has insufficient size for UDP match: %d", len(buf))
}
@@ -94,11 +93,11 @@ func (*UDPMatcher) Name() string {
}
// Match implements Matcher.Match.
-func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfaceName string) (bool, bool) {
+func (um *UDPMatcher) Match(hook stack.Hook, pkt stack.PacketBuffer, interfaceName string) (bool, bool) {
netHeader := header.IPv4(pkt.NetworkHeader)
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
- // into the iptables.Check codepath as matchers are added.
+ // into the stack.Check codepath as matchers are added.
if netHeader.TransportProtocol() != header.UDPProtocolNumber {
return false, false
}
@@ -114,7 +113,7 @@ func (um *UDPMatcher) Match(hook iptables.Hook, pkt tcpip.PacketBuffer, interfac
// Now we need the transport header. However, this may not have been set
// yet.
// TODO(gvisor.dev/issue/170): Parsing the transport header should
- // ultimately be moved into the iptables.Check codepath as matchers are
+ // ultimately be moved into the stack.Check codepath as matchers are
// added.
var udpHeader header.UDP
if pkt.TransportHeader != nil {
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index ab01cb4fa..cbf46b1e9 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -38,7 +38,6 @@ go_library(
"//pkg/tcpip",
"//pkg/tcpip/buffer",
"//pkg/tcpip/header",
- "//pkg/tcpip/iptables",
"//pkg/tcpip/network/ipv4",
"//pkg/tcpip/network/ipv6",
"//pkg/tcpip/stack",
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e187276c5..7ac38764d 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -29,6 +29,7 @@ import (
"io"
"math"
"reflect"
+ "sync/atomic"
"syscall"
"time"
@@ -264,6 +265,12 @@ type SocketOperations struct {
skType linux.SockType
protocol int
+ // readViewHasData is 1 iff readView has data to be read, 0 otherwise.
+ // Must be accessed using atomic operations. It must only be written
+ // with readMu held but can be read without holding readMu. The latter
+ // is required to avoid deadlocks in epoll Readiness checks.
+ readViewHasData uint32
+
// readMu protects access to the below fields.
readMu sync.Mutex `state:"nosave"`
// readView contains the remaining payload from the last packet.
@@ -293,7 +300,7 @@ type SocketOperations struct {
// New creates a new endpoint socket.
func New(t *kernel.Task, family int, skType linux.SockType, protocol int, queue *waiter.Queue, endpoint tcpip.Endpoint) (*fs.File, *syserr.Error) {
if skType == linux.SOCK_STREAM {
- if err := endpoint.SetSockOptInt(tcpip.DelayOption, 1); err != nil {
+ if err := endpoint.SetSockOptBool(tcpip.DelayOption, true); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
}
@@ -410,21 +417,24 @@ func (s *SocketOperations) isPacketBased() bool {
// fetchReadView updates the readView field of the socket if it's currently
// empty. It assumes that the socket is locked.
+//
+// Precondition: s.readMu must be held.
func (s *SocketOperations) fetchReadView() *syserr.Error {
if len(s.readView) > 0 {
return nil
}
-
s.readView = nil
s.sender = tcpip.FullAddress{}
v, cms, err := s.Endpoint.Read(&s.sender)
if err != nil {
+ atomic.StoreUint32(&s.readViewHasData, 0)
return syserr.TranslateNetstackError(err)
}
s.readView = v
s.readCM = cms
+ atomic.StoreUint32(&s.readViewHasData, 1)
return nil
}
@@ -525,7 +535,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -598,7 +608,7 @@ func (s *SocketOperations) ReadFrom(ctx context.Context, _ *fs.File, r io.Reader
}
if resCh != nil {
- t := ctx.(*kernel.Task)
+ t := kernel.TaskFromContext(ctx)
if err := t.Block(resCh); err != nil {
return 0, syserr.FromError(err).ToError()
}
@@ -623,11 +633,9 @@ func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
// Check our cached value iff the caller asked for readability and the
// endpoint itself is currently not readable.
if (mask & ^r & waiter.EventIn) != 0 {
- s.readMu.Lock()
- if len(s.readView) > 0 {
+ if atomic.LoadUint32(&s.readViewHasData) == 1 {
r |= waiter.EventIn
}
- s.readMu.Unlock()
}
return r
@@ -655,7 +663,7 @@ func (s *SocketOperations) checkFamily(family uint16, exact bool) *syserr.Error
// This is a hack to work around the fact that both IPv4 and IPv6 ANY are
// represented by the empty string.
//
-// TODO(gvisor.dev/issues/1556): remove this function.
+// TODO(gvisor.dev/issue/1556): remove this function.
func (s *SocketOperations) mapFamily(addr tcpip.FullAddress, family uint16) tcpip.FullAddress {
if len(addr.Addr) == 0 && s.family == linux.AF_INET6 && family == linux.AF_INET {
addr.Addr = "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\xff\xff\x00\x00\x00\x00"
@@ -712,14 +720,44 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
// Bind implements the linux syscall bind(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
- addr, family, err := AddressAndFamily(sockaddr)
- if err != nil {
- return err
+ if len(sockaddr) < 2 {
+ return syserr.ErrInvalidArgument
}
- if err := s.checkFamily(family, true /* exact */); err != nil {
- return err
+
+ family := usermem.ByteOrder.Uint16(sockaddr)
+ var addr tcpip.FullAddress
+
+ // Bind for AF_PACKET requires only family, protocol and ifindex.
+ // In function AddressAndFamily, we check the address length which is
+ // not needed for AF_PACKET bind.
+ if family == linux.AF_PACKET {
+ var a linux.SockAddrLink
+ if len(sockaddr) < sockAddrLinkSize {
+ return syserr.ErrInvalidArgument
+ }
+ binary.Unmarshal(sockaddr[:sockAddrLinkSize], usermem.ByteOrder, &a)
+
+ if a.Protocol != uint16(s.protocol) {
+ return syserr.ErrInvalidArgument
+ }
+
+ addr = tcpip.FullAddress{
+ NIC: tcpip.NICID(a.InterfaceIndex),
+ Addr: tcpip.Address(a.HardwareAddr[:header.EthernetAddressSize]),
+ }
+ } else {
+ var err *syserr.Error
+ addr, family, err = AddressAndFamily(sockaddr)
+ if err != nil {
+ return err
+ }
+
+ if err = s.checkFamily(family, true /* exact */); err != nil {
+ return err
+ }
+
+ addr = s.mapFamily(addr, family)
}
- addr = s.mapFamily(addr, family)
// Issue the bind request to the endpoint.
return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
@@ -902,7 +940,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (interface{}, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -927,8 +965,15 @@ func GetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int,
return nil, syserr.ErrProtocolNotAvailable
}
+func boolToInt32(v bool) int32 {
+ if v {
+ return 1
+ }
+ return 0
+}
+
// getSockOptSocket implements GetSockOpt when level is SOL_SOCKET.
-func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
+func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, name, outLen int) (interface{}, *syserr.Error) {
// TODO(b/124056281): Stop rejecting short optLen values in getsockopt.
switch name {
case linux.SO_ERROR:
@@ -960,12 +1005,11 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.PasscredOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.PasscredOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_SNDBUF:
if outLen < sizeOfInt32 {
@@ -1004,24 +1048,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReuseAddressOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReuseAddressOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_REUSEPORT:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.ReusePortOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.ReusePortOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_BINDTODEVICE:
var v tcpip.BindToDeviceOption
@@ -1051,24 +1093,22 @@ func getSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, family
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.BroadcastOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.BroadcastOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_KEEPALIVE:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.KeepaliveEnabledOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.KeepaliveEnabledOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.SO_LINGER:
if outLen < linux.SizeOfLinger {
@@ -1118,47 +1158,41 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (interfa
return nil, syserr.ErrInvalidArgument
}
- v, err := ep.GetSockOptInt(tcpip.DelayOption)
+ v, err := ep.GetSockOptBool(tcpip.DelayOption)
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v == 0 {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(!v), nil
case linux.TCP_CORK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.CorkOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.CorkOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_QUICKACK:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.QuickAckOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.QuickAckOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- return int32(v), nil
+ return boolToInt32(v), nil
case linux.TCP_MAXSEG:
if outLen < sizeOfInt32 {
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MaxSegOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MaxSegOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1290,11 +1324,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
case linux.IPV6_PATHMTU:
t.Kernel().EmitUnimplementedEvent(t)
@@ -1304,8 +1334,8 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if outLen == 0 {
return make([]byte, 0), nil
}
- var v tcpip.IPv6TrafficClassOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv6TrafficClassOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1327,12 +1357,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (interf
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
default:
emitUnimplementedEventIPv6(t, name)
@@ -1348,8 +1373,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.TTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.TTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1365,8 +1390,8 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastTTLOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.MulticastTTLOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
@@ -1391,23 +1416,19 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
return nil, syserr.ErrInvalidArgument
}
- var v tcpip.MulticastLoopOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptBool(tcpip.MulticastLoopOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- if v {
- return int32(1), nil
- }
- return int32(0), nil
+ return boolToInt32(v), nil
case linux.IP_TOS:
// Length handling for parity with Linux.
if outLen == 0 {
return []byte(nil), nil
}
- var v tcpip.IPv4TOSOption
- if err := ep.GetSockOpt(&v); err != nil {
+ v, err := ep.GetSockOptInt(tcpip.IPv4TOSOption)
+ if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
if outLen < sizeOfInt32 {
@@ -1424,11 +1445,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
case linux.IP_PKTINFO:
if outLen < sizeOfInt32 {
@@ -1439,11 +1456,7 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
if err != nil {
return nil, syserr.TranslateNetstackError(err)
}
- var o int32
- if v {
- o = 1
- }
- return o, nil
+ return boolToInt32(v), nil
default:
emitUnimplementedEventIP(t, name)
@@ -1503,7 +1516,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// SetSockOpt can be used to implement the linux syscall setsockopt(2) for
// sockets backed by a commonEndpoint.
-func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
+func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int, name int, optVal []byte) *syserr.Error {
switch level {
case linux.SOL_SOCKET:
return setSockOptSocket(t, s, ep, name, optVal)
@@ -1530,7 +1543,7 @@ func SetSockOpt(t *kernel.Task, s socket.Socket, ep commonEndpoint, level int, n
}
// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
-func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.SO_SNDBUF:
if len(optVal) < sizeOfInt32 {
@@ -1554,7 +1567,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReuseAddressOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReuseAddressOption, v != 0))
case linux.SO_REUSEPORT:
if len(optVal) < sizeOfInt32 {
@@ -1562,7 +1575,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.ReusePortOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReusePortOption, v != 0))
case linux.SO_BINDTODEVICE:
n := bytes.IndexByte(optVal, 0)
@@ -1590,7 +1603,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BroadcastOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.BroadcastOption, v != 0))
case linux.SO_PASSCRED:
if len(optVal) < sizeOfInt32 {
@@ -1598,7 +1611,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.PasscredOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.PasscredOption, v != 0))
case linux.SO_KEEPALIVE:
if len(optVal) < sizeOfInt32 {
@@ -1606,7 +1619,7 @@ func setSockOptSocket(t *kernel.Task, s socket.Socket, ep commonEndpoint, name i
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveEnabledOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.KeepaliveEnabledOption, v != 0))
case linux.SO_SNDTIMEO:
if len(optVal) < linux.SizeOfTimeval {
@@ -1678,11 +1691,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- var o int
- if v == 0 {
- o = 1
- }
- return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.DelayOption, o))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.DelayOption, v == 0))
case linux.TCP_CORK:
if len(optVal) < sizeOfInt32 {
@@ -1690,7 +1699,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.CorkOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.CorkOption, v != 0))
case linux.TCP_QUICKACK:
if len(optVal) < sizeOfInt32 {
@@ -1698,7 +1707,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.QuickAckOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.QuickAckOption, v != 0))
case linux.TCP_MAXSEG:
if len(optVal) < sizeOfInt32 {
@@ -1706,7 +1715,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MaxSegOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MaxSegOption, int(v)))
case linux.TCP_KEEPIDLE:
if len(optVal) < sizeOfInt32 {
@@ -1817,7 +1826,7 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
if v == -1 {
v = 0
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv6TrafficClassOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv6TrafficClassOption, int(v)))
case linux.IPV6_RECVTCLASS:
v, err := parseIntOrChar(optVal)
@@ -1902,7 +1911,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if v < 0 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastTTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.MulticastTTLOption, int(v)))
case linux.IP_ADD_MEMBERSHIP:
req, err := copyInMulticastRequest(optVal, false /* allowAddr */)
@@ -1949,9 +1958,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(
- tcpip.MulticastLoopOption(v != 0),
- ))
+ return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.MulticastLoopOption, v != 0))
case linux.MCAST_JOIN_GROUP:
// FIXME(b/124219304): Implement MCAST_JOIN_GROUP.
@@ -1970,7 +1977,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
} else if v < 1 || v > 255 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TTLOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.TTLOption, int(v)))
case linux.IP_TOS:
if len(optVal) == 0 {
@@ -1980,7 +1987,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
if err != nil {
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.IPv4TOSOption(v)))
+ return syserr.TranslateNetstackError(ep.SetSockOptInt(tcpip.IPv4TOSOption, int(v)))
case linux.IP_RECVTOS:
v, err := parseIntOrChar(optVal)
@@ -2304,6 +2311,10 @@ func (s *SocketOperations) coalescingRead(ctx context.Context, dst usermem.IOSeq
}
copied += n
s.readView.TrimFront(n)
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
dst = dst.DropFirst(n)
if e != nil {
err = syserr.FromError(e)
@@ -2350,9 +2361,9 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
// caller-supplied buffer.
s.readMu.Lock()
n, err := s.coalescingRead(ctx, dst, trunc)
- s.readMu.Unlock()
cmsg := s.controlMessages()
s.fillCmsgInq(&cmsg)
+ s.readMu.Unlock()
return n, 0, nil, 0, cmsg, err
}
@@ -2426,6 +2437,10 @@ func (s *SocketOperations) nonBlockingRead(ctx context.Context, dst usermem.IOSe
s.readView.TrimFront(int(n))
}
+ if len(s.readView) == 0 {
+ atomic.StoreUint32(&s.readViewHasData, 0)
+ }
+
var flags int
if msgLen > int(n) {
flags |= linux.MSG_TRUNC
@@ -2637,7 +2652,9 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
}
// Add bytes removed from the endpoint but not yet sent to the caller.
+ s.readMu.Lock()
v += len(s.readView)
+ s.readMu.Unlock()
if v > math.MaxInt32 {
v = math.MaxInt32
diff --git a/pkg/sentry/socket/netstack/provider.go b/pkg/sentry/socket/netstack/provider.go
index 5f181f017..c3f04b613 100644
--- a/pkg/sentry/socket/netstack/provider.go
+++ b/pkg/sentry/socket/netstack/provider.go
@@ -62,10 +62,6 @@ func getTransportProtocol(ctx context.Context, stype linux.SockType, protocol in
}
case linux.SOCK_RAW:
- // TODO(b/142504697): "In order to create a raw socket, a
- // process must have the CAP_NET_RAW capability in the user
- // namespace that governs its network namespace." - raw(7)
-
// Raw sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(ctx)
if !creds.HasCapability(linux.CAP_NET_RAW) {
@@ -126,6 +122,12 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
ep, e = eps.Stack.NewRawEndpoint(transProto, p.netProto, wq, associated)
} else {
ep, e = eps.Stack.NewEndpoint(transProto, p.netProto, wq)
+
+ // Assign task to PacketOwner interface to get the UID and GID for
+ // iptables owner matching.
+ if e == nil {
+ ep.SetOwner(t)
+ }
}
if e != nil {
return nil, syserr.TranslateNetstackError(e)
@@ -135,10 +137,6 @@ func (p *provider) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*
}
func packetSocket(t *kernel.Task, epStack *Stack, stype linux.SockType, protocol int) (*fs.File, *syserr.Error) {
- // TODO(b/142504697): "In order to create a packet socket, a process
- // must have the CAP_NET_RAW capability in the user namespace that
- // governs its network namespace." - packet(7)
-
// Packet sockets require CAP_NET_RAW.
creds := auth.CredentialsFromContext(t)
if !creds.HasCapability(linux.CAP_NET_RAW) {
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 0692482e9..f5fa18136 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -23,7 +23,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/header"
- "gvisor.dev/gvisor/pkg/tcpip/iptables"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
"gvisor.dev/gvisor/pkg/tcpip/network/ipv6"
"gvisor.dev/gvisor/pkg/tcpip/stack"
@@ -200,36 +199,66 @@ func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
// Statistics implements inet.Stack.Statistics.
func (s *Stack) Statistics(stat interface{}, arg string) error {
switch stats := stat.(type) {
+ case *inet.StatDev:
+ for _, ni := range s.Stack.NICInfo() {
+ if ni.Name != arg {
+ continue
+ }
+ // TODO(gvisor.dev/issue/2103) Support stubbed stats.
+ *stats = inet.StatDev{
+ // Receive section.
+ ni.Stats.Rx.Bytes.Value(), // bytes.
+ ni.Stats.Rx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // frame.
+ 0, // compressed.
+ 0, // multicast.
+ // Transmit section.
+ ni.Stats.Tx.Bytes.Value(), // bytes.
+ ni.Stats.Tx.Packets.Value(), // packets.
+ 0, // errs.
+ 0, // drop.
+ 0, // fifo.
+ 0, // colls.
+ 0, // carrier.
+ 0, // compressed.
+ }
+ break
+ }
case *inet.StatSNMPIP:
ip := Metrics.IP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPIP{
- 0, // TODO(gvisor.dev/issue/969): Support Ip/Forwarding.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/DefaultTTL.
+ 0, // Ip/Forwarding.
+ 0, // Ip/DefaultTTL.
ip.PacketsReceived.Value(), // InReceives.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InHdrErrors.
+ 0, // Ip/InHdrErrors.
ip.InvalidDestinationAddressesReceived.Value(), // InAddrErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ForwDatagrams.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InUnknownProtos.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/InDiscards.
+ 0, // Ip/ForwDatagrams.
+ 0, // Ip/InUnknownProtos.
+ 0, // Ip/InDiscards.
ip.PacketsDelivered.Value(), // InDelivers.
ip.PacketsSent.Value(), // OutRequests.
ip.OutgoingPacketErrors.Value(), // OutDiscards.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/OutNoRoutes.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmTimeout.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmReqds.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/ReasmFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragOKs.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragFails.
- 0, // TODO(gvisor.dev/issue/969): Support Ip/FragCreates.
+ 0, // Ip/OutNoRoutes.
+ 0, // Support Ip/ReasmTimeout.
+ 0, // Support Ip/ReasmReqds.
+ 0, // Support Ip/ReasmOKs.
+ 0, // Support Ip/ReasmFails.
+ 0, // Support Ip/FragOKs.
+ 0, // Support Ip/FragFails.
+ 0, // Support Ip/FragCreates.
}
case *inet.StatSNMPICMP:
in := Metrics.ICMP.V4PacketsReceived.ICMPv4PacketStats
out := Metrics.ICMP.V4PacketsSent.ICMPv4PacketStats
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPICMP{
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InMsgs.
+ 0, // Icmp/InMsgs.
Metrics.ICMP.V4PacketsSent.Dropped.Value(), // InErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/InCsumErrors.
+ 0, // Icmp/InCsumErrors.
in.DstUnreachable.Value(), // InDestUnreachs.
in.TimeExceeded.Value(), // InTimeExcds.
in.ParamProblem.Value(), // InParmProbs.
@@ -241,7 +270,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
in.TimestampReply.Value(), // InTimestampReps.
in.InfoRequest.Value(), // InAddrMasks.
in.InfoReply.Value(), // InAddrMaskReps.
- 0, // TODO(gvisor.dev/issue/969): Support Icmp/OutMsgs.
+ 0, // Icmp/OutMsgs.
Metrics.ICMP.V4PacketsReceived.Invalid.Value(), // OutErrors.
out.DstUnreachable.Value(), // OutDestUnreachs.
out.TimeExceeded.Value(), // OutTimeExcds.
@@ -277,15 +306,16 @@ func (s *Stack) Statistics(stat interface{}, arg string) error {
}
case *inet.StatSNMPUDP:
udp := Metrics.UDP
+ // TODO(gvisor.dev/issue/969) Support stubbed stats.
*stats = inet.StatSNMPUDP{
udp.PacketsReceived.Value(), // InDatagrams.
udp.UnknownPortErrors.Value(), // NoPorts.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InErrors.
+ 0, // Udp/InErrors.
udp.PacketsSent.Value(), // OutDatagrams.
udp.ReceiveBufferErrors.Value(), // RcvbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/SndbufErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/InCsumErrors.
- 0, // TODO(gvisor.dev/issue/969): Support Udp/IgnoredMulti.
+ 0, // Udp/SndbufErrors.
+ 0, // Udp/InCsumErrors.
+ 0, // Udp/IgnoredMulti.
}
default:
return syserr.ErrEndpointOperation.ToError()
@@ -332,7 +362,7 @@ func (s *Stack) RouteTable() []inet.Route {
}
// IPTables returns the stack's iptables.
-func (s *Stack) IPTables() (iptables.IPTables, error) {
+func (s *Stack) IPTables() (stack.IPTables, error) {
return s.Stack.IPTables(), nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 50d9744e6..6580bd6e9 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -31,6 +31,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/kernel"
ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
@@ -48,11 +49,25 @@ func (c *ControlMessages) Release() {
c.Unix.Release()
}
-// Socket is the interface containing socket syscalls used by the syscall layer
-// to redirect them to the appropriate implementation.
+// Socket is an interface combining fs.FileOperations and SocketOps,
+// representing a VFS1 socket file.
type Socket interface {
fs.FileOperations
+ SocketOps
+}
+
+// SocketVFS2 is an interface combining vfs.FileDescription and SocketOps,
+// representing a VFS2 socket file.
+type SocketVFS2 interface {
+ vfs.FileDescriptionImpl
+ SocketOps
+}
+// SocketOps is the interface containing socket syscalls used by the syscall
+// layer to redirect them to the appropriate implementation.
+//
+// It is implemented by both Socket and SocketVFS2.
+type SocketOps interface {
// Connect implements the connect(2) linux syscall.
Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error
@@ -153,6 +168,8 @@ var families = make(map[int][]Provider)
// RegisterProvider registers the provider of a given address family so that
// sockets of that type can be created via socket() and/or socketpair()
// syscalls.
+//
+// This should only be called during the initialization of the address family.
func RegisterProvider(family int, provider Provider) {
families[family] = append(families[family], provider)
}
@@ -216,6 +233,74 @@ func NewDirent(ctx context.Context, d *device.Device) *fs.Dirent {
return fs.NewDirent(ctx, inode, fmt.Sprintf("socket:[%d]", ino))
}
+// ProviderVFS2 is the vfs2 interface implemented by providers of sockets for
+// specific address families (e.g., AF_INET).
+type ProviderVFS2 interface {
+ // Socket creates a new socket.
+ //
+ // If a nil Socket _and_ a nil error is returned, it means that the
+ // protocol is not supported. A non-nil error should only be returned
+ // if the protocol is supported, but an error occurs during creation.
+ Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error)
+
+ // Pair creates a pair of connected sockets.
+ //
+ // See Socket for error information.
+ Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error)
+}
+
+// familiesVFS2 holds a map of all known address families and their providers.
+var familiesVFS2 = make(map[int][]ProviderVFS2)
+
+// RegisterProviderVFS2 registers the provider of a given address family so that
+// sockets of that type can be created via socket() and/or socketpair()
+// syscalls.
+//
+// This should only be called during the initialization of the address family.
+func RegisterProviderVFS2(family int, provider ProviderVFS2) {
+ familiesVFS2[family] = append(familiesVFS2[family], provider)
+}
+
+// NewVFS2 creates a new socket with the given family, type and protocol.
+func NewVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ for _, p := range familiesVFS2[family] {
+ s, err := p.Socket(t, stype, protocol)
+ if err != nil {
+ return nil, err
+ }
+ if s != nil {
+ t.Kernel().RecordSocketVFS2(s)
+ return s, nil
+ }
+ }
+
+ return nil, syserr.ErrAddressFamilyNotSupported
+}
+
+// PairVFS2 creates a new connected socket pair with the given family, type and
+// protocol.
+func PairVFS2(t *kernel.Task, family int, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ providers, ok := familiesVFS2[family]
+ if !ok {
+ return nil, nil, syserr.ErrAddressFamilyNotSupported
+ }
+
+ for _, p := range providers {
+ s1, s2, err := p.Pair(t, stype, protocol)
+ if err != nil {
+ return nil, nil, err
+ }
+ if s1 != nil && s2 != nil {
+ k := t.Kernel()
+ k.RecordSocketVFS2(s1)
+ k.RecordSocketVFS2(s2)
+ return s1, s2, nil
+ }
+ }
+
+ return nil, nil, syserr.ErrSocketNotSupported
+}
+
// SendReceiveTimeout stores timeouts for send and receive calls.
//
// It is meant to be embedded into Socket implementations to help satisfy the
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index 08743deba..de2cc4bdf 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -8,23 +8,27 @@ go_library(
"device.go",
"io.go",
"unix.go",
+ "unix_vfs2.go",
],
visibility = ["//pkg/sentry:internal"],
deps = [
"//pkg/abi/linux",
"//pkg/context",
+ "//pkg/fspath",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
+ "//pkg/sentry/fsimpl/sockfs",
"//pkg/sentry/kernel",
"//pkg/sentry/kernel/time",
"//pkg/sentry/socket",
"//pkg/sentry/socket/control",
"//pkg/sentry/socket/netstack",
"//pkg/sentry/socket/unix/transport",
+ "//pkg/sentry/vfs",
"//pkg/syserr",
"//pkg/syserror",
"//pkg/tcpip",
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index 74bcd6300..c708b6030 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -30,6 +30,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/context",
"//pkg/ilist",
+ "//pkg/log",
"//pkg/refs",
"//pkg/sync",
"//pkg/syserr",
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 2ef654235..2f1b127df 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -20,6 +20,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -838,24 +839,43 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
// SetSockOpt sets a socket option. Currently not supported.
func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
- switch v := opt.(type) {
- case tcpip.PasscredOption:
- e.setPasscred(v != 0)
- return nil
- }
return nil
}
func (e *baseEndpoint) SetSockOptBool(opt tcpip.SockOptBool, v bool) *tcpip.Error {
+ switch opt {
+ case tcpip.BroadcastOption:
+ case tcpip.PasscredOption:
+ e.setPasscred(v)
+ case tcpip.ReuseAddressOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
+ switch opt {
+ case tcpip.SendBufferSizeOption:
+ case tcpip.ReceiveBufferSizeOption:
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ }
return nil
}
func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) {
- return false, tcpip.ErrUnknownProtocolOption
+ switch opt {
+ case tcpip.KeepaliveEnabledOption:
+ return false, nil
+
+ case tcpip.PasscredOption:
+ return e.Passcred(), nil
+
+ default:
+ log.Warningf("Unsupported socket option: %d", opt)
+ return false, tcpip.ErrUnknownProtocolOption
+ }
}
func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
@@ -914,29 +934,19 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
return int(v), nil
default:
+ log.Warningf("Unsupported socket option: %d", opt)
return -1, tcpip.ErrUnknownProtocolOption
}
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch o := opt.(type) {
+ switch opt.(type) {
case tcpip.ErrorOption:
return nil
- case *tcpip.PasscredOption:
- if e.Passcred() {
- *o = tcpip.PasscredOption(1)
- } else {
- *o = tcpip.PasscredOption(0)
- }
- return nil
-
- case *tcpip.KeepaliveEnabledOption:
- *o = 0
- return nil
-
default:
+ log.Warningf("Unsupported socket option: %T", opt)
return tcpip.ErrUnknownProtocolOption
}
}
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index 4d30aa714..7c64f30fa 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -23,6 +23,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
"gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -33,6 +34,7 @@ import (
"gvisor.dev/gvisor/pkg/sentry/socket/control"
"gvisor.dev/gvisor/pkg/sentry/socket/netstack"
"gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -52,11 +54,8 @@ type SocketOperations struct {
fsutil.FileNoSplice `state:"nosave"`
fsutil.FileNoopFlush `state:"nosave"`
fsutil.FileUseInodeUnstableAttr `state:"nosave"`
- refs.AtomicRefCount
- socket.SendReceiveTimeout
- ep transport.Endpoint
- stype linux.SockType
+ socketOpsCommon
}
// New creates a new unix socket.
@@ -75,16 +74,29 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty
}
s := SocketOperations{
- ep: ep,
- stype: stype,
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
}
s.EnableLeakCheck("unix.SocketOperations")
return fs.NewFile(ctx, d, flags, &s)
}
+// socketOpsCommon contains the socket operations common to VFS1 and VFS2.
+//
+// +stateify savable
+type socketOpsCommon struct {
+ refs.AtomicRefCount
+ socket.SendReceiveTimeout
+
+ ep transport.Endpoint
+ stype linux.SockType
+}
+
// DecRef implements RefCounter.DecRef.
-func (s *SocketOperations) DecRef() {
+func (s *socketOpsCommon) DecRef() {
s.DecRefWithDestructor(func() {
s.ep.Close()
})
@@ -97,7 +109,7 @@ func (s *SocketOperations) Release() {
s.DecRef()
}
-func (s *SocketOperations) isPacket() bool {
+func (s *socketOpsCommon) isPacket() bool {
switch s.stype {
case linux.SOCK_DGRAM, linux.SOCK_SEQPACKET:
return true
@@ -110,7 +122,7 @@ func (s *SocketOperations) isPacket() bool {
}
// Endpoint extracts the transport.Endpoint.
-func (s *SocketOperations) Endpoint() transport.Endpoint {
+func (s *socketOpsCommon) Endpoint() transport.Endpoint {
return s.ep
}
@@ -143,7 +155,7 @@ func extractPath(sockaddr []byte) (string, *syserr.Error) {
// GetPeerName implements the linux syscall getpeername(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetRemoteAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -155,7 +167,7 @@ func (s *SocketOperations) GetPeerName(t *kernel.Task) (linux.SockAddr, uint32,
// GetSockName implements the linux syscall getsockname(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
+func (s *socketOpsCommon) GetSockName(t *kernel.Task) (linux.SockAddr, uint32, *syserr.Error) {
addr, err := s.ep.GetLocalAddress()
if err != nil {
return nil, 0, syserr.TranslateNetstackError(err)
@@ -178,7 +190,7 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
// Listen implements the linux syscall listen(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Listen(t *kernel.Task, backlog int) *syserr.Error {
+func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
return s.ep.Listen(backlog)
}
@@ -310,6 +322,8 @@ func (s *SocketOperations) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Create the socket.
+ //
+ // TODO(gvisor.dev/issue/2324): Correctly set file permissions.
childDir, err := d.Bind(t, t.FSContext().RootDirectory(), name, bep, fs.FilePermissions{User: fs.PermMask{Read: true}})
if err != nil {
return syserr.ErrPortInUse
@@ -345,6 +359,31 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
return ep, nil
}
+ if kernel.VFS2Enabled {
+ p := fspath.Parse(path)
+ root := t.FSContext().RootDirectoryVFS2()
+ start := root
+ relPath := !p.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: p,
+ FollowFinalSymlink: true,
+ }
+ ep, e := t.Kernel().VFS().BoundEndpointAt(t, t.Credentials(), &pop)
+ root.DecRef()
+ if relPath {
+ start.DecRef()
+ }
+ if e != nil {
+ return nil, syserr.FromError(e)
+ }
+ return ep, nil
+ }
+
// Find the node in the filesystem.
root := t.FSContext().RootDirectory()
cwd := t.FSContext().WorkingDirectory()
@@ -363,12 +402,11 @@ func extractEndpoint(t *kernel.Task, sockaddr []byte) (transport.BoundEndpoint,
// No socket!
return nil, syserr.ErrConnectionRefused
}
-
return ep, nil
}
// Connect implements the linux syscall connect(2) for unix sockets.
-func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
+func (s *socketOpsCommon) Connect(t *kernel.Task, sockaddr []byte, blocking bool) *syserr.Error {
ep, err := extractEndpoint(t, sockaddr)
if err != nil {
return err
@@ -379,7 +417,7 @@ func (s *SocketOperations) Connect(t *kernel.Task, sockaddr []byte, blocking boo
return s.ep.Connect(t, ep)
}
-// Writev implements fs.FileOperations.Write.
+// Write implements fs.FileOperations.Write.
func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IOSequence, _ int64) (int64, error) {
t := kernel.TaskFromContext(ctx)
ctrl := control.New(t, s.ep, nil)
@@ -399,7 +437,7 @@ func (s *SocketOperations) Write(ctx context.Context, _ *fs.File, src usermem.IO
// SendMsg implements the linux syscall sendmsg(2) for unix sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
+func (s *socketOpsCommon) SendMsg(t *kernel.Task, src usermem.IOSequence, to []byte, flags int, haveDeadline bool, deadline ktime.Time, controlMessages socket.ControlMessages) (int, *syserr.Error) {
w := EndpointWriter{
Ctx: t,
Endpoint: s.ep,
@@ -453,27 +491,27 @@ func (s *SocketOperations) SendMsg(t *kernel.Task, src usermem.IOSequence, to []
}
// Passcred implements transport.Credentialer.Passcred.
-func (s *SocketOperations) Passcred() bool {
+func (s *socketOpsCommon) Passcred() bool {
return s.ep.Passcred()
}
// ConnectedPasscred implements transport.Credentialer.ConnectedPasscred.
-func (s *SocketOperations) ConnectedPasscred() bool {
+func (s *socketOpsCommon) ConnectedPasscred() bool {
return s.ep.ConnectedPasscred()
}
// Readiness implements waiter.Waitable.Readiness.
-func (s *SocketOperations) Readiness(mask waiter.EventMask) waiter.EventMask {
+func (s *socketOpsCommon) Readiness(mask waiter.EventMask) waiter.EventMask {
return s.ep.Readiness(mask)
}
// EventRegister implements waiter.Waitable.EventRegister.
-func (s *SocketOperations) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+func (s *socketOpsCommon) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
s.ep.EventRegister(e, mask)
}
// EventUnregister implements waiter.Waitable.EventUnregister.
-func (s *SocketOperations) EventUnregister(e *waiter.Entry) {
+func (s *socketOpsCommon) EventUnregister(e *waiter.Entry) {
s.ep.EventUnregister(e)
}
@@ -485,7 +523,7 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
// Shutdown implements the linux syscall shutdown(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) Shutdown(t *kernel.Task, how int) *syserr.Error {
+func (s *socketOpsCommon) Shutdown(t *kernel.Task, how int) *syserr.Error {
f, err := netstack.ConvertShutdown(how)
if err != nil {
return err
@@ -511,7 +549,7 @@ func (s *SocketOperations) Read(ctx context.Context, _ *fs.File, dst usermem.IOS
// RecvMsg implements the linux syscall recvmsg(2) for sockets backed by
// a transport.Endpoint.
-func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
+func (s *socketOpsCommon) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags int, haveDeadline bool, deadline ktime.Time, senderRequested bool, controlDataLen uint64) (n int, msgFlags int, senderAddr linux.SockAddr, senderAddrLen uint32, controlMessages socket.ControlMessages, err *syserr.Error) {
trunc := flags&linux.MSG_TRUNC != 0
peek := flags&linux.MSG_PEEK != 0
dontWait := flags&linux.MSG_DONTWAIT != 0
@@ -648,12 +686,12 @@ func (s *SocketOperations) RecvMsg(t *kernel.Task, dst usermem.IOSequence, flags
}
// State implements socket.Socket.State.
-func (s *SocketOperations) State() uint32 {
+func (s *socketOpsCommon) State() uint32 {
return s.ep.State()
}
// Type implements socket.Socket.Type.
-func (s *SocketOperations) Type() (family int, skType linux.SockType, protocol int) {
+func (s *socketOpsCommon) Type() (family int, skType linux.SockType, protocol int) {
// Unix domain sockets always have a protocol of 0.
return linux.AF_UNIX, s.stype, 0
}
@@ -706,4 +744,5 @@ func (*provider) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*fs.F
func init() {
socket.RegisterProvider(linux.AF_UNIX, &provider{})
+ socket.RegisterProviderVFS2(linux.AF_UNIX, &providerVFS2{})
}
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
new file mode 100644
index 000000000..3e54d49c4
--- /dev/null
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -0,0 +1,348 @@
+// Copyright 2018 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package unix
+
+import (
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/sentry/arch"
+ "gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/sentry/socket/control"
+ "gvisor.dev/gvisor/pkg/sentry/socket/netstack"
+ "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport"
+ "gvisor.dev/gvisor/pkg/sentry/vfs"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/syserror"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/usermem"
+ "gvisor.dev/gvisor/pkg/waiter"
+)
+
+// SocketVFS2 implements socket.SocketVFS2 (and by extension,
+// vfs.FileDescriptionImpl) for Unix sockets.
+type SocketVFS2 struct {
+ vfsfd vfs.FileDescription
+ vfs.FileDescriptionDefaultImpl
+ vfs.DentryMetadataFileDescriptionImpl
+
+ socketOpsCommon
+}
+
+// NewVFS2File creates and returns a new vfs.FileDescription for a unix socket.
+func NewVFS2File(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) {
+ sock := NewFDImpl(ep, stype)
+ vfsfd := &sock.vfsfd
+ if err := sockfs.InitSocket(sock, vfsfd, t.Kernel().SocketMount(), t.Credentials()); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ return vfsfd, nil
+}
+
+// NewFDImpl creates and returns a new SocketVFS2.
+func NewFDImpl(ep transport.Endpoint, stype linux.SockType) *SocketVFS2 {
+ // You can create AF_UNIX, SOCK_RAW sockets. They're the same as
+ // SOCK_DGRAM and don't require CAP_NET_RAW.
+ if stype == linux.SOCK_RAW {
+ stype = linux.SOCK_DGRAM
+ }
+
+ return &SocketVFS2{
+ socketOpsCommon: socketOpsCommon{
+ ep: ep,
+ stype: stype,
+ },
+ }
+}
+
+// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level int, name int, outPtr usermem.Addr, outLen int) (interface{}, *syserr.Error) {
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+}
+
+// blockingAccept implements a blocking version of accept(2), that is, if no
+// connections are ready to be accept, it will block until one becomes ready.
+func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+ // Register for notifications.
+ e, ch := waiter.NewChannelEntry(nil)
+ s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
+ defer s.socketOpsCommon.EventUnregister(&e)
+
+ // Try to accept the connection; if it fails, then wait until we get a
+ // notification.
+ for {
+ if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ return ep, err
+ }
+
+ if err := t.Block(ch); err != nil {
+ return nil, syserr.FromError(err)
+ }
+ }
+}
+
+// Accept implements the linux syscall accept(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
+ // Issue the accept request to get the new endpoint.
+ ep, err := s.ep.Accept()
+ if err != nil {
+ if err != syserr.ErrWouldBlock || !blocking {
+ return 0, nil, 0, err
+ }
+
+ var err *syserr.Error
+ ep, err = s.blockingAccept(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ // We expect this to be a FileDescription here.
+ ns, err := NewVFS2File(t, ep, s.stype)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ defer ns.DecRef()
+
+ if flags&linux.SOCK_NONBLOCK != 0 {
+ ns.SetStatusFlags(t, t.Credentials(), linux.SOCK_NONBLOCK)
+ }
+
+ var addr linux.SockAddr
+ var addrLen uint32
+ if peerRequested {
+ // Get address of the peer.
+ var err *syserr.Error
+ addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
+ if err != nil {
+ return 0, nil, 0, err
+ }
+ }
+
+ fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
+ CloseOnExec: flags&linux.SOCK_CLOEXEC != 0,
+ })
+ if e != nil {
+ return 0, nil, 0, syserr.FromError(e)
+ }
+
+ t.Kernel().RecordSocketVFS2(ns)
+ return fd, addr, addrLen, nil
+}
+
+// Bind implements the linux syscall bind(2) for unix sockets.
+func (s *SocketVFS2) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
+ p, e := extractPath(sockaddr)
+ if e != nil {
+ return e
+ }
+
+ bep, ok := s.ep.(transport.BoundEndpoint)
+ if !ok {
+ // This socket can't be bound.
+ return syserr.ErrInvalidArgument
+ }
+
+ return s.ep.Bind(tcpip.FullAddress{Addr: tcpip.Address(p)}, func() *syserr.Error {
+ // Is it abstract?
+ if p[0] == 0 {
+ if t.IsNetworkNamespaced() {
+ return syserr.ErrInvalidEndpointState
+ }
+ if err := t.AbstractSockets().Bind(p[1:], bep, s); err != nil {
+ // syserr.ErrPortInUse corresponds to EADDRINUSE.
+ return syserr.ErrPortInUse
+ }
+ } else {
+ path := fspath.Parse(p)
+ root := t.FSContext().RootDirectoryVFS2()
+ defer root.DecRef()
+ start := root
+ relPath := !path.Absolute
+ if relPath {
+ start = t.FSContext().WorkingDirectoryVFS2()
+ defer start.DecRef()
+ }
+ pop := vfs.PathOperation{
+ Root: root,
+ Start: start,
+ Path: path,
+ }
+ err := t.Kernel().VFS().MknodAt(t, t.Credentials(), &pop, &vfs.MknodOptions{
+ // TODO(gvisor.dev/issue/2324): The file permissions should be taken
+ // from s and t.FSContext().Umask() (see net/unix/af_unix.c:unix_bind),
+ // but VFS1 just always uses 0400. Resolve this inconsistency.
+ Mode: linux.S_IFSOCK | 0400,
+ Endpoint: bep,
+ })
+ if err == syserror.EEXIST {
+ return syserr.ErrAddressInUse
+ }
+ return syserr.FromError(err)
+ }
+
+ return nil
+ })
+}
+
+// Ioctl implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.SyscallArguments) (uintptr, error) {
+ return netstack.Ioctl(ctx, s.ep, uio, args)
+}
+
+// PRead implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Read implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ if dst.NumBytes() == 0 {
+ return 0, nil
+ }
+ return dst.CopyOutFrom(ctx, &EndpointReader{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ NumRights: 0,
+ Peek: false,
+ From: nil,
+ })
+}
+
+// PWrite implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) {
+ return 0, syserror.ESPIPE
+}
+
+// Write implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) {
+ // All flags other than RWF_NOWAIT should be ignored.
+ // TODO(gvisor.dev/issue/1476): Support RWF_NOWAIT.
+ if opts.Flags != 0 {
+ return 0, syserror.EOPNOTSUPP
+ }
+
+ t := kernel.TaskFromContext(ctx)
+ ctrl := control.New(t, s.ep, nil)
+
+ if src.NumBytes() == 0 {
+ nInt, err := s.ep.SendMsg(ctx, [][]byte{}, ctrl, nil)
+ return int64(nInt), err.ToError()
+ }
+
+ return src.CopyInTo(ctx, &EndpointWriter{
+ Ctx: ctx,
+ Endpoint: s.ep,
+ Control: ctrl,
+ To: nil,
+ })
+}
+
+// Release implements vfs.FileDescriptionImpl.
+func (s *SocketVFS2) Release() {
+ // Release only decrements a reference on s because s may be referenced in
+ // the abstract socket namespace.
+ s.DecRef()
+}
+
+// Readiness implements waiter.Waitable.Readiness.
+func (s *SocketVFS2) Readiness(mask waiter.EventMask) waiter.EventMask {
+ return s.socketOpsCommon.Readiness(mask)
+}
+
+// EventRegister implements waiter.Waitable.EventRegister.
+func (s *SocketVFS2) EventRegister(e *waiter.Entry, mask waiter.EventMask) {
+ s.socketOpsCommon.EventRegister(e, mask)
+}
+
+// EventUnregister implements waiter.Waitable.EventUnregister.
+func (s *SocketVFS2) EventUnregister(e *waiter.Entry) {
+ s.socketOpsCommon.EventUnregister(e)
+}
+
+// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
+// a transport.Endpoint.
+func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []byte) *syserr.Error {
+ return netstack.SetSockOpt(t, s, s.ep, level, name, optVal)
+}
+
+// providerVFS2 is a unix domain socket provider for VFS2.
+type providerVFS2 struct{}
+
+func (*providerVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, syserr.ErrProtocolNotSupported
+ }
+
+ // Create the endpoint and socket.
+ var ep transport.Endpoint
+ switch stype {
+ case linux.SOCK_DGRAM, linux.SOCK_RAW:
+ ep = transport.NewConnectionless(t)
+ case linux.SOCK_SEQPACKET, linux.SOCK_STREAM:
+ ep = transport.NewConnectioned(t, stype, t.Kernel())
+ default:
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ f, err := NewVFS2File(t, ep, stype)
+ if err != nil {
+ ep.Close()
+ return nil, err
+ }
+ return f, nil
+}
+
+// Pair creates a new pair of AF_UNIX connected sockets.
+func (*providerVFS2) Pair(t *kernel.Task, stype linux.SockType, protocol int) (*vfs.FileDescription, *vfs.FileDescription, *syserr.Error) {
+ // Check arguments.
+ if protocol != 0 && protocol != linux.AF_UNIX /* PF_UNIX */ {
+ return nil, nil, syserr.ErrProtocolNotSupported
+ }
+
+ switch stype {
+ case linux.SOCK_STREAM, linux.SOCK_DGRAM, linux.SOCK_SEQPACKET, linux.SOCK_RAW:
+ // Ok
+ default:
+ return nil, nil, syserr.ErrInvalidArgument
+ }
+
+ // Create the endpoints and sockets.
+ ep1, ep2 := transport.NewPair(t, stype, t.Kernel())
+ s1, err := NewVFS2File(t, ep1, stype)
+ if err != nil {
+ ep1.Close()
+ ep2.Close()
+ return nil, nil, err
+ }
+ s2, err := NewVFS2File(t, ep2, stype)
+ if err != nil {
+ s1.DecRef()
+ ep2.Close()
+ return nil, nil, err
+ }
+
+ return s1, s2, nil
+}