summaryrefslogtreecommitdiffhomepage
path: root/pkg/sentry/socket
diff options
context:
space:
mode:
Diffstat (limited to 'pkg/sentry/socket')
-rw-r--r--pkg/sentry/socket/BUILD2
-rw-r--r--pkg/sentry/socket/hostinet/BUILD4
-rw-r--r--pkg/sentry/socket/hostinet/socket.go4
-rw-r--r--pkg/sentry/socket/hostinet/socket_vfs2.go5
-rw-r--r--pkg/sentry/socket/hostinet/stack.go15
-rw-r--r--pkg/sentry/socket/netfilter/BUILD2
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go260
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go265
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go313
-rw-r--r--pkg/sentry/socket/netfilter/targets.go10
-rw-r--r--pkg/sentry/socket/netfilter/tcp_matcher.go32
-rw-r--r--pkg/sentry/socket/netfilter/udp_matcher.go32
-rw-r--r--pkg/sentry/socket/netlink/BUILD4
-rw-r--r--pkg/sentry/socket/netlink/socket.go4
-rw-r--r--pkg/sentry/socket/netstack/BUILD4
-rw-r--r--pkg/sentry/socket/netstack/netstack.go366
-rw-r--r--pkg/sentry/socket/netstack/netstack_vfs2.go80
-rw-r--r--pkg/sentry/socket/netstack/stack.go28
-rw-r--r--pkg/sentry/socket/socket.go2
-rw-r--r--pkg/sentry/socket/unix/BUILD2
-rw-r--r--pkg/sentry/socket/unix/transport/BUILD12
-rw-r--r--pkg/sentry/socket/unix/transport/connectioned.go22
-rw-r--r--pkg/sentry/socket/unix/transport/connectionless.go6
-rw-r--r--pkg/sentry/socket/unix/transport/queue.go13
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go45
-rw-r--r--pkg/sentry/socket/unix/unix.go26
-rw-r--r--pkg/sentry/socket/unix/unix_vfs2.go26
27 files changed, 1036 insertions, 548 deletions
diff --git a/pkg/sentry/socket/BUILD b/pkg/sentry/socket/BUILD
index c0fd3425b..a3f775d15 100644
--- a/pkg/sentry/socket/BUILD
+++ b/pkg/sentry/socket/BUILD
@@ -10,6 +10,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/marshal",
"//pkg/sentry/device",
"//pkg/sentry/fs",
"//pkg/sentry/fs/fsutil",
@@ -20,6 +21,5 @@ go_library(
"//pkg/syserr",
"//pkg/tcpip",
"//pkg/usermem",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/hostinet/BUILD b/pkg/sentry/socket/hostinet/BUILD
index 8448ea401..b6ebe29d6 100644
--- a/pkg/sentry/socket/hostinet/BUILD
+++ b/pkg/sentry/socket/hostinet/BUILD
@@ -21,6 +21,8 @@ go_library(
"//pkg/context",
"//pkg/fdnotifier",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/safemem",
"//pkg/sentry/arch",
"//pkg/sentry/device",
@@ -43,8 +45,6 @@ go_library(
"//pkg/tcpip/stack",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/hostinet/socket.go b/pkg/sentry/socket/hostinet/socket.go
index 242e6bf76..7d3c4a01c 100644
--- a/pkg/sentry/socket/hostinet/socket.go
+++ b/pkg/sentry/socket/hostinet/socket.go
@@ -24,6 +24,8 @@ import (
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fdnotifier"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -36,8 +38,6 @@ import (
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const (
diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go
index 8a1d52ebf..97bc6027f 100644
--- a/pkg/sentry/socket/hostinet/socket_vfs2.go
+++ b/pkg/sentry/socket/hostinet/socket_vfs2.go
@@ -97,11 +97,6 @@ func (s *socketVFS2) Ioctl(ctx context.Context, uio usermem.IO, args arch.Syscal
return ioctl(ctx, s.fd, uio, args)
}
-// Allocate implements vfs.FileDescriptionImpl.Allocate.
-func (s *socketVFS2) Allocate(ctx context.Context, mode, offset, length uint64) error {
- return syserror.ENODEV
-}
-
// PRead implements vfs.FileDescriptionImpl.PRead.
func (s *socketVFS2) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) {
return 0, syserror.ESPIPE
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 3d3fabb30..faa61160e 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -123,18 +123,11 @@ func (s *Stack) Configure() error {
s.netSNMPFile = f
}
- s.ipv4Forwarding = false
- if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil {
- s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0"
+ s.ipv6Forwarding = false
+ if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv6/conf/all/forwarding"); err == nil {
+ s.ipv6Forwarding = strings.TrimSpace(string(ipForwarding)) != "0"
} else {
- log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false")
- }
-
- s.ipv4Forwarding = false
- if ipForwarding, err := ioutil.ReadFile("/proc/sys/net/ipv4/ip_forward"); err == nil {
- s.ipv4Forwarding = strings.TrimSpace(string(ipForwarding)) != "0"
- } else {
- log.Warningf("Failed to read if IPv4 forwarding is enabled, setting to false")
+ log.Warningf("Failed to read if ipv6 forwarding is enabled, setting to false")
}
return nil
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 721094bbf..8aea0200f 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -6,6 +6,8 @@ go_library(
name = "netfilter",
srcs = [
"extensions.go",
+ "ipv4.go",
+ "ipv6.go",
"netfilter.go",
"owner_matcher.go",
"targets.go",
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
new file mode 100644
index 000000000..e4c55a100
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -0,0 +1,260 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// emptyIPv4Filter is for comparison with a rule's filters to determine whether
+// it is also empty. It is immutable.
+var emptyIPv4Filter = stack.IPHeaderFilter{
+ Dst: "\x00\x00\x00\x00",
+ DstMask: "\x00\x00\x00\x00",
+ Src: "\x00\x00\x00\x00",
+ SrcMask: "\x00\x00\x00\x00",
+}
+
+// convertNetstackToBinary4 converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a little data, reading some
+// offsets, jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+ }
+
+ table, ok := stk.IPTables().GetTable(tablename.String(), false)
+ if !ok {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
+ // Setup the info struct.
+ entries, info := getEntries4(table, tablename)
+ return entries, info, nil
+}
+
+func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo) {
+ var info linux.IPTGetinfo
+ var entries linux.KernelIPTGetEntries
+ copy(info.Name[:], tablename[:])
+ copy(entries.Name[:], info.Name[:])
+ info.ValidHooks = table.ValidHooks()
+
+ for ruleIdx, rule := range table.Rules {
+ nflog("convert to binary: current offset: %d", entries.Size)
+
+ setHooksAndUnderflow(&info, table, entries.Size, ruleIdx)
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIPTEntry{
+ Entry: linux.IPTEntry{
+ IP: linux.IPTIP{
+ Protocol: uint16(rule.Filter.Protocol),
+ },
+ NextOffset: linux.SizeOfIPTEntry,
+ TargetOffset: linux.SizeOfIPTEntry,
+ },
+ }
+ copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IP.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ if rule.Filter.DstInvert {
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
+ }
+ if rule.Filter.SrcInvert {
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
+ }
+ if rule.Filter.OutputInterfaceInvert {
+ entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ nflog("convert to binary: matcher serialized as: %v", serialized)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+
+ nflog("convert to binary: adding entry: %+v", entry)
+
+ entries.Size += uint32(entry.Entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ info.NumEntries++
+ }
+
+ info.Size = entries.Size
+ nflog("convert to binary: finished with an marshalled size of %d", info.Size)
+ return entries, info
+}
+
+func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
+ nflog("set entries: setting entries in table %q", replace.Name.String())
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ nflog("set entries: processing entry at offset %d", offset)
+
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIPTEntry {
+ nflog("optVal has insufficient size for entry %d", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ var entry linux.IPTEntry
+ buf := optVal[:linux.SizeOfIPTEntry]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ initialOptValLen := len(optVal)
+ optVal = optVal[linux.SizeOfIPTEntry:]
+
+ if entry.TargetOffset < linux.SizeOfIPTEntry {
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support more IPTIP
+ // filtering fields.
+ filter, err := filterFromIPTIP(entry.IP)
+ if err != nil {
+ nflog("bad iptip: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Matchers and targets can specify
+ // that they only work for certain protocols, hooks, tables.
+ // Get matchers.
+ matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
+ if len(optVal) < int(matchersSize) {
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ if err != nil {
+ nflog("failed to parse matchers: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+ optVal = optVal[matchersSize:]
+
+ // Get the target of the rule.
+ targetSize := entry.NextOffset - entry.TargetOffset
+ if len(optVal) < int(targetSize) {
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ target, err := parseTarget(filter, optVal[:targetSize])
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, stack.Rule{
+ Filter: filter,
+ Target: target,
+ Matchers: matchers,
+ })
+ offsets[offset] = int(entryIdx)
+ offset += uint32(entry.NextOffset)
+
+ if initialOptValLen-len(optVal) != int(entry.NextOffset) {
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ }
+ return offsets, nil
+}
+
+func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
+ if containsUnsupportedFields4(iptip) {
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
+ }
+ if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
+ }
+
+ n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterface)
+ }
+ ifname := string(iptip.OutputInterface[:n])
+
+ n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterfaceMask)
+ }
+ ifnameMask := string(iptip.OutputInterfaceMask[:n])
+
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ // A Protocol value of 0 indicates all protocols match.
+ CheckProtocol: iptip.Protocol != 0,
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
+ Src: tcpip.Address(iptip.Src[:]),
+ SrcMask: tcpip.Address(iptip.SrcMask[:]),
+ SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
+ OutputInterface: ifname,
+ OutputInterfaceMask: ifnameMask,
+ OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
+ }, nil
+}
+
+func containsUnsupportedFields4(iptip linux.IPTIP) bool {
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - Src and SrcMask
+ // - The inverse destination IP check flag
+ // - OutputInterface, OutputInterfaceMask and its inverse.
+ var emptyInterface = [linux.IFNAMSIZ]byte{}
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
+ return iptip.InputInterface != emptyInterface ||
+ iptip.InputInterfaceMask != emptyInterface ||
+ iptip.Flags != 0 ||
+ iptip.InverseFlags&^inverseMask != 0
+}
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
new file mode 100644
index 000000000..3b2c1becd
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -0,0 +1,265 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// emptyIPv6Filter is for comparison with a rule's filters to determine whether
+// it is also empty. It is immutable.
+var emptyIPv6Filter = stack.IPHeaderFilter{
+ Dst: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ DstMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Src: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ SrcMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+}
+
+// convertNetstackToBinary6 converts the ip6tables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a little data, reading some
+// offsets, jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo, error) {
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+ }
+
+ table, ok := stk.IPTables().GetTable(tablename.String(), true)
+ if !ok {
+ return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
+ // Setup the info struct, which is the same in IPv4 and IPv6.
+ entries, info := getEntries6(table, tablename)
+ return entries, info, nil
+}
+
+func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo) {
+ var info linux.IPTGetinfo
+ var entries linux.KernelIP6TGetEntries
+ copy(info.Name[:], tablename[:])
+ copy(entries.Name[:], info.Name[:])
+ info.ValidHooks = table.ValidHooks()
+
+ for ruleIdx, rule := range table.Rules {
+ nflog("convert to binary: current offset: %d", entries.Size)
+
+ setHooksAndUnderflow(&info, table, entries.Size, ruleIdx)
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIP6TEntry{
+ Entry: linux.IP6TEntry{
+ IPv6: linux.IP6TIP{
+ Protocol: uint16(rule.Filter.Protocol),
+ },
+ NextOffset: linux.SizeOfIP6TEntry,
+ TargetOffset: linux.SizeOfIP6TEntry,
+ },
+ }
+ copy(entry.Entry.IPv6.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IPv6.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IPv6.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IPv6.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IPv6.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IPv6.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ if rule.Filter.DstInvert {
+ entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_DSTIP
+ }
+ if rule.Filter.SrcInvert {
+ entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_SRCIP
+ }
+ if rule.Filter.OutputInterfaceInvert {
+ entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_VIA_OUT
+ }
+ if rule.Filter.CheckProtocol {
+ entry.Entry.IPv6.Flags |= linux.IP6T_F_PROTO
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ nflog("convert to binary: matcher serialized as: %v", serialized)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+
+ nflog("convert to binary: adding entry: %+v", entry)
+
+ entries.Size += uint32(entry.Entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ info.NumEntries++
+ }
+
+ info.Size = entries.Size
+ nflog("convert to binary: finished with an marshalled size of %d", info.Size)
+ return entries, info
+}
+
+func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
+ nflog("set entries: setting entries in table %q", replace.Name.String())
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ nflog("set entries: processing entry at offset %d", offset)
+
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIP6TEntry {
+ nflog("optVal has insufficient size for entry %d", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ var entry linux.IP6TEntry
+ buf := optVal[:linux.SizeOfIP6TEntry]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ initialOptValLen := len(optVal)
+ optVal = optVal[linux.SizeOfIP6TEntry:]
+
+ if entry.TargetOffset < linux.SizeOfIP6TEntry {
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support more IPTIP
+ // filtering fields.
+ filter, err := filterFromIP6TIP(entry.IPv6)
+ if err != nil {
+ nflog("bad iptip: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Matchers and targets can specify
+ // that they only work for certain protocols, hooks, tables.
+ // Get matchers.
+ matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry
+ if len(optVal) < int(matchersSize) {
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ if err != nil {
+ nflog("failed to parse matchers: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+ optVal = optVal[matchersSize:]
+
+ // Get the target of the rule.
+ targetSize := entry.NextOffset - entry.TargetOffset
+ if len(optVal) < int(targetSize) {
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ target, err := parseTarget(filter, optVal[:targetSize])
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, stack.Rule{
+ Filter: filter,
+ Target: target,
+ Matchers: matchers,
+ })
+ offsets[offset] = int(entryIdx)
+ offset += uint32(entry.NextOffset)
+
+ if initialOptValLen-len(optVal) != int(entry.NextOffset) {
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ }
+ return offsets, nil
+}
+
+func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) {
+ if containsUnsupportedFields6(iptip) {
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv6AddressSize || len(iptip.DstMask) != header.IPv6AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
+ }
+ if len(iptip.Src) != header.IPv6AddressSize || len(iptip.SrcMask) != header.IPv6AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
+ }
+
+ n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterface)
+ }
+ ifname := string(iptip.OutputInterface[:n])
+
+ n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterfaceMask)
+ }
+ ifnameMask := string(iptip.OutputInterfaceMask[:n])
+
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ // In ip6tables a flag controls whether to check the protocol.
+ CheckProtocol: iptip.Flags&linux.IP6T_F_PROTO != 0,
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IP6T_INV_DSTIP != 0,
+ Src: tcpip.Address(iptip.Src[:]),
+ SrcMask: tcpip.Address(iptip.SrcMask[:]),
+ SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0,
+ OutputInterface: ifname,
+ OutputInterfaceMask: ifnameMask,
+ OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0,
+ }, nil
+}
+
+func containsUnsupportedFields6(iptip linux.IP6TIP) bool {
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - Src and SrcMask
+ // - The inverse destination IP check flag
+ // - OutputInterface, OutputInterfaceMask and its inverse.
+ var emptyInterface = [linux.IFNAMSIZ]byte{}
+ flagMask := uint8(linux.IP6T_F_PROTO)
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT)
+ return iptip.InputInterface != emptyInterface ||
+ iptip.InputInterfaceMask != emptyInterface ||
+ iptip.Flags&^flagMask != 0 ||
+ iptip.InverseFlags&^inverseMask != 0 ||
+ iptip.TOS != 0
+}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index e91b0624c..871ea80ee 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -17,7 +17,6 @@
package netfilter
import (
- "bytes"
"errors"
"fmt"
@@ -26,8 +25,6 @@ import (
"gvisor.dev/gvisor/pkg/log"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/syserr"
- "gvisor.dev/gvisor/pkg/tcpip"
- "gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/stack"
"gvisor.dev/gvisor/pkg/usermem"
)
@@ -37,15 +34,6 @@ import (
// developing iptables, but can pollute sentry logs otherwise.
const enableLogging = false
-// emptyFilter is for comparison with a rule's filters to determine whether it
-// is also empty. It is immutable.
-var emptyFilter = stack.IPHeaderFilter{
- Dst: "\x00\x00\x00\x00",
- DstMask: "\x00\x00\x00\x00",
- Src: "\x00\x00\x00\x00",
- SrcMask: "\x00\x00\x00\x00",
-}
-
// nflog logs messages related to the writing and reading of iptables.
func nflog(format string, args ...interface{}) {
if enableLogging && log.IsLogging(log.Debug) {
@@ -54,14 +42,19 @@ func nflog(format string, args ...interface{}) {
}
// GetInfo returns information about iptables.
-func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
+func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
var info linux.IPTGetinfo
if _, err := info.CopyIn(t, outPtr); err != nil {
return linux.IPTGetinfo{}, syserr.FromError(err)
}
- _, info, err := convertNetstackToBinary(stack, info.Name)
+ var err error
+ if ipv6 {
+ _, info, err = convertNetstackToBinary6(stack, info.Name)
+ } else {
+ _, info, err = convertNetstackToBinary4(stack, info.Name)
+ }
if err != nil {
nflog("couldn't convert iptables: %v", err)
return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
@@ -71,8 +64,8 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
return info, nil
}
-// GetEntries returns netstack's iptables rules encoded for the iptables tool.
-func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
+// GetEntries4 returns netstack's iptables rules.
+func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
// Read in the struct and table name.
var userEntries linux.IPTGetEntries
if _, err := userEntries.CopyIn(t, outPtr); err != nil {
@@ -82,7 +75,7 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
- entries, _, err := convertNetstackToBinary(stack, userEntries.Name)
+ entries, _, err := convertNetstackToBinary4(stack, userEntries.Name)
if err != nil {
nflog("couldn't read entries: %v", err)
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
@@ -95,112 +88,53 @@ func GetEntries(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
return entries, nil
}
-// convertNetstackToBinary converts the iptables as stored in netstack to the
-// format expected by the iptables tool. Linux stores each table as a binary
-// blob that can only be traversed by parsing a bit, reading some offsets,
-// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(stack *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
- table, ok := stack.IPTables().GetTable(tablename.String())
- if !ok {
- return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+// GetEntries6 returns netstack's ip6tables rules.
+func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) {
+ // Read in the struct and table name. IPv4 and IPv6 utilize structs
+ // with the same layout.
+ var userEntries linux.IPTGetEntries
+ if _, err := userEntries.CopyIn(t, outPtr); err != nil {
+ nflog("couldn't copy in entries %q", userEntries.Name)
+ return linux.KernelIP6TGetEntries{}, syserr.FromError(err)
}
- var entries linux.KernelIPTGetEntries
- var info linux.IPTGetinfo
- info.ValidHooks = table.ValidHooks()
-
- // The table name has to fit in the struct.
- if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
- return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+ // Convert netstack's iptables rules to something that the iptables
+ // tool can understand.
+ entries, _, err := convertNetstackToBinary6(stack, userEntries.Name)
+ if err != nil {
+ nflog("couldn't read entries: %v", err)
+ return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
+ }
+ if binary.Size(entries) > uintptr(outLen) {
+ nflog("insufficient GetEntries output size: %d", uintptr(outLen))
+ return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
}
- copy(info.Name[:], tablename[:])
- copy(entries.Name[:], tablename[:])
-
- for ruleIdx, rule := range table.Rules {
- nflog("convert to binary: current offset: %d", entries.Size)
-
- // Is this a chain entry point?
- for hook, hookRuleIdx := range table.BuiltinChains {
- if hookRuleIdx == ruleIdx {
- nflog("convert to binary: found hook %d at offset %d", hook, entries.Size)
- info.HookEntry[hook] = entries.Size
- }
- }
- // Is this a chain underflow point?
- for underflow, underflowRuleIdx := range table.Underflows {
- if underflowRuleIdx == ruleIdx {
- nflog("convert to binary: found underflow %d at offset %d", underflow, entries.Size)
- info.Underflow[underflow] = entries.Size
- }
- }
- // Each rule corresponds to an entry.
- entry := linux.KernelIPTEntry{
- Entry: linux.IPTEntry{
- IP: linux.IPTIP{
- Protocol: uint16(rule.Filter.Protocol),
- },
- NextOffset: linux.SizeOfIPTEntry,
- TargetOffset: linux.SizeOfIPTEntry,
- },
- }
- copy(entry.Entry.IP.Dst[:], rule.Filter.Dst)
- copy(entry.Entry.IP.DstMask[:], rule.Filter.DstMask)
- copy(entry.Entry.IP.Src[:], rule.Filter.Src)
- copy(entry.Entry.IP.SrcMask[:], rule.Filter.SrcMask)
- copy(entry.Entry.IP.OutputInterface[:], rule.Filter.OutputInterface)
- copy(entry.Entry.IP.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
- if rule.Filter.DstInvert {
- entry.Entry.IP.InverseFlags |= linux.IPT_INV_DSTIP
- }
- if rule.Filter.SrcInvert {
- entry.Entry.IP.InverseFlags |= linux.IPT_INV_SRCIP
- }
- if rule.Filter.OutputInterfaceInvert {
- entry.Entry.IP.InverseFlags |= linux.IPT_INV_VIA_OUT
- }
+ return entries, nil
+}
- for _, matcher := range rule.Matchers {
- // Serialize the matcher and add it to the
- // entry.
- serialized := marshalMatcher(matcher)
- nflog("convert to binary: matcher serialized as: %v", serialized)
- if len(serialized)%8 != 0 {
- panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
- }
- entry.Elems = append(entry.Elems, serialized...)
- entry.Entry.NextOffset += uint16(len(serialized))
- entry.Entry.TargetOffset += uint16(len(serialized))
+// setHooksAndUnderflow checks whether the rule at ruleIdx is a hook entrypoint
+// or underflow, in which case it fills in info.HookEntry and info.Underflows.
+func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint32, ruleIdx int) {
+ // Is this a chain entry point?
+ for hook, hookRuleIdx := range table.BuiltinChains {
+ if hookRuleIdx == ruleIdx {
+ nflog("convert to binary: found hook %d at offset %d", hook, offset)
+ info.HookEntry[hook] = offset
}
-
- // Serialize and append the target.
- serialized := marshalTarget(rule.Target)
- if len(serialized)%8 != 0 {
- panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
+ // Is this a chain underflow point?
+ for underflow, underflowRuleIdx := range table.Underflows {
+ if underflowRuleIdx == ruleIdx {
+ nflog("convert to binary: found underflow %d at offset %d", underflow, offset)
+ info.Underflow[underflow] = offset
}
- entry.Elems = append(entry.Elems, serialized...)
- entry.Entry.NextOffset += uint16(len(serialized))
-
- nflog("convert to binary: adding entry: %+v", entry)
-
- entries.Size += uint32(entry.Entry.NextOffset)
- entries.Entrytable = append(entries.Entrytable, entry)
- info.NumEntries++
}
-
- nflog("convert to binary: finished with an marshalled size of %d", info.Size)
- info.Size = entries.Size
- return entries, info, nil
}
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
- // Get the basic rules data (struct ipt_replace).
- if len(optVal) < linux.SizeOfIPTReplace {
- nflog("optVal has insufficient size for replace %d", len(optVal))
- return syserr.ErrInvalidArgument
- }
+func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
replaceBuf := optVal[:linux.SizeOfIPTReplace]
optVal = optVal[linux.SizeOfIPTReplace:]
@@ -212,85 +146,25 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
case stack.FilterTable:
table = stack.EmptyFilterTable()
case stack.NATTable:
+ if ipv6 {
+ nflog("IPv6 redirection not yet supported (gvisor.dev/issue/3549)")
+ return syserr.ErrInvalidArgument
+ }
table = stack.EmptyNATTable()
default:
nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String())
return syserr.ErrInvalidArgument
}
- nflog("set entries: setting entries in table %q", replace.Name.String())
-
- // Convert input into a list of rules and their offsets.
- var offset uint32
- // offsets maps rule byte offsets to their position in table.Rules.
- offsets := map[uint32]int{}
- for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
- nflog("set entries: processing entry at offset %d", offset)
-
- // Get the struct ipt_entry.
- if len(optVal) < linux.SizeOfIPTEntry {
- nflog("optVal has insufficient size for entry %d", len(optVal))
- return syserr.ErrInvalidArgument
- }
- var entry linux.IPTEntry
- buf := optVal[:linux.SizeOfIPTEntry]
- binary.Unmarshal(buf, usermem.ByteOrder, &entry)
- initialOptValLen := len(optVal)
- optVal = optVal[linux.SizeOfIPTEntry:]
-
- if entry.TargetOffset < linux.SizeOfIPTEntry {
- nflog("entry has too-small target offset %d", entry.TargetOffset)
- return syserr.ErrInvalidArgument
- }
-
- // TODO(gvisor.dev/issue/170): We should support more IPTIP
- // filtering fields.
- filter, err := filterFromIPTIP(entry.IP)
- if err != nil {
- nflog("bad iptip: %v", err)
- return syserr.ErrInvalidArgument
- }
-
- // TODO(gvisor.dev/issue/170): Matchers and targets can specify
- // that they only work for certain protocols, hooks, tables.
- // Get matchers.
- matchersSize := entry.TargetOffset - linux.SizeOfIPTEntry
- if len(optVal) < int(matchersSize) {
- nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
- return syserr.ErrInvalidArgument
- }
- matchers, err := parseMatchers(filter, optVal[:matchersSize])
- if err != nil {
- nflog("failed to parse matchers: %v", err)
- return syserr.ErrInvalidArgument
- }
- optVal = optVal[matchersSize:]
-
- // Get the target of the rule.
- targetSize := entry.NextOffset - entry.TargetOffset
- if len(optVal) < int(targetSize) {
- nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
- return syserr.ErrInvalidArgument
- }
- target, err := parseTarget(filter, optVal[:targetSize])
- if err != nil {
- nflog("failed to parse target: %v", err)
- return syserr.ErrInvalidArgument
- }
- optVal = optVal[targetSize:]
-
- table.Rules = append(table.Rules, stack.Rule{
- Filter: filter,
- Target: target,
- Matchers: matchers,
- })
- offsets[offset] = int(entryIdx)
- offset += uint32(entry.NextOffset)
-
- if initialOptValLen-len(optVal) != int(entry.NextOffset) {
- nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
- return syserr.ErrInvalidArgument
- }
+ var err *syserr.Error
+ var offsets map[uint32]int
+ if ipv6 {
+ offsets, err = modifyEntries6(stk, optVal, &replace, &table)
+ } else {
+ offsets, err = modifyEntries4(stk, optVal, &replace, &table)
+ }
+ if err != nil {
+ return err
}
// Go through the list of supported hooks for this table and, for each
@@ -305,7 +179,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
table.BuiltinChains[hk] = ruleIdx
}
if offset == replace.Underflow[hook] {
- if !validUnderflow(table.Rules[ruleIdx]) {
+ if !validUnderflow(table.Rules[ruleIdx], ipv6) {
nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
return syserr.ErrInvalidArgument
}
@@ -323,7 +197,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
}
}
- // Add the user chains.
+ // Check the user chains.
for ruleIdx, rule := range table.Rules {
if _, ok := rule.Target.(stack.UserChainTarget); !ok {
continue
@@ -370,7 +244,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
if ruleIdx == stack.HookUnset {
continue
}
- if !isUnconditionalAccept(table.Rules[ruleIdx]) {
+ if !isUnconditionalAccept(table.Rules[ruleIdx], ipv6) {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
@@ -382,7 +256,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table))
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6))
+
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
@@ -404,7 +279,6 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
// Check some invariants.
if match.MatchSize < linux.SizeOfXTEntryMatch {
-
return nil, fmt.Errorf("match size is too small, must be at least %d", linux.SizeOfXTEntryMatch)
}
if len(optVal) < int(match.MatchSize) {
@@ -429,64 +303,11 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
return matchers, nil
}
-func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
- if containsUnsupportedFields(iptip) {
- return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
- }
- if len(iptip.Dst) != header.IPv4AddressSize || len(iptip.DstMask) != header.IPv4AddressSize {
- return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
- }
- if len(iptip.Src) != header.IPv4AddressSize || len(iptip.SrcMask) != header.IPv4AddressSize {
- return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
- }
-
- n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterface)
- }
- ifname := string(iptip.OutputInterface[:n])
-
- n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
- if n == -1 {
- n = len(iptip.OutputInterfaceMask)
- }
- ifnameMask := string(iptip.OutputInterfaceMask[:n])
-
- return stack.IPHeaderFilter{
- Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
- Dst: tcpip.Address(iptip.Dst[:]),
- DstMask: tcpip.Address(iptip.DstMask[:]),
- DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
- Src: tcpip.Address(iptip.Src[:]),
- SrcMask: tcpip.Address(iptip.SrcMask[:]),
- SrcInvert: iptip.InverseFlags&linux.IPT_INV_SRCIP != 0,
- OutputInterface: ifname,
- OutputInterfaceMask: ifnameMask,
- OutputInterfaceInvert: iptip.InverseFlags&linux.IPT_INV_VIA_OUT != 0,
- }, nil
-}
-
-func containsUnsupportedFields(iptip linux.IPTIP) bool {
- // The following features are supported:
- // - Protocol
- // - Dst and DstMask
- // - Src and SrcMask
- // - The inverse destination IP check flag
- // - OutputInterface, OutputInterfaceMask and its inverse.
- var emptyInterface = [linux.IFNAMSIZ]byte{}
- // Disable any supported inverse flags.
- inverseMask := uint8(linux.IPT_INV_DSTIP) | uint8(linux.IPT_INV_SRCIP) | uint8(linux.IPT_INV_VIA_OUT)
- return iptip.InputInterface != emptyInterface ||
- iptip.InputInterfaceMask != emptyInterface ||
- iptip.Flags != 0 ||
- iptip.InverseFlags&^inverseMask != 0
-}
-
-func validUnderflow(rule stack.Rule) bool {
+func validUnderflow(rule stack.Rule, ipv6 bool) bool {
if len(rule.Matchers) != 0 {
return false
}
- if rule.Filter != emptyFilter {
+ if (ipv6 && rule.Filter != emptyIPv6Filter) || (!ipv6 && rule.Filter != emptyIPv4Filter) {
return false
}
switch rule.Target.(type) {
@@ -497,8 +318,8 @@ func validUnderflow(rule stack.Rule) bool {
}
}
-func isUnconditionalAccept(rule stack.Rule) bool {
- if !validUnderflow(rule) {
+func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
+ if !validUnderflow(rule, ipv6) {
return false
}
_, ok := rule.Target.(stack.AcceptTarget)
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 8ebdaff18..87e41abd8 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -218,8 +218,8 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
}
- if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p)
}
var redirectTarget linux.XTRedirectTarget
@@ -232,7 +232,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
// RangeSize should be 1.
if nfRange.RangeSize != 1 {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize)
}
// TODO(gvisor.dev/issue/170): Check if the flags are valid.
@@ -240,7 +240,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
// For now, redirect target only supports destination port change.
// Port range and IP range are not supported yet.
if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags)
}
target.RangeProtoSpecified = true
@@ -249,7 +249,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
// TODO(gvisor.dev/issue/170): Port range is not supported yet.
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
}
// Convert port from big endian to little endian.
diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go
index 0bfd6c1f4..844acfede 100644
--- a/pkg/sentry/socket/netfilter/tcp_matcher.go
+++ b/pkg/sentry/socket/netfilter/tcp_matcher.go
@@ -97,17 +97,33 @@ func (*TCPMatcher) Name() string {
// Match implements Matcher.Match.
func (tm *TCPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
+ // TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
+ // into the stack.Check codepath as matchers are added.
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if netHeader.TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
+ }
- if netHeader.TransportProtocol() != header.TCPProtocolNumber {
- return false, false
- }
+ // We don't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
+ }
- // We dont't match fragments.
- if frag := netHeader.FragmentOffset(); frag != 0 {
- if frag == 1 {
- return false, true
+ case header.IPv6ProtocolNumber:
+ // As in Linux, we do not perform an IPv6 fragment check. See
+ // xt_action_param.fragoff in
+ // include/linux/netfilter/x_tables.h.
+ if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.TCPProtocolNumber {
+ return false, false
}
+
+ default:
+ // We don't know the network protocol.
return false, false
}
diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go
index 7ed05461d..63201201c 100644
--- a/pkg/sentry/socket/netfilter/udp_matcher.go
+++ b/pkg/sentry/socket/netfilter/udp_matcher.go
@@ -94,19 +94,33 @@ func (*UDPMatcher) Name() string {
// Match implements Matcher.Match.
func (um *UDPMatcher) Match(hook stack.Hook, pkt *stack.PacketBuffer, interfaceName string) (bool, bool) {
- netHeader := header.IPv4(pkt.NetworkHeader().View())
-
// TODO(gvisor.dev/issue/170): Proto checks should ultimately be moved
// into the stack.Check codepath as matchers are added.
- if netHeader.TransportProtocol() != header.UDPProtocolNumber {
- return false, false
- }
+ switch pkt.NetworkProtocolNumber {
+ case header.IPv4ProtocolNumber:
+ netHeader := header.IPv4(pkt.NetworkHeader().View())
+ if netHeader.TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
- // We dont't match fragments.
- if frag := netHeader.FragmentOffset(); frag != 0 {
- if frag == 1 {
- return false, true
+ // We don't match fragments.
+ if frag := netHeader.FragmentOffset(); frag != 0 {
+ if frag == 1 {
+ return false, true
+ }
+ return false, false
}
+
+ case header.IPv6ProtocolNumber:
+ // As in Linux, we do not perform an IPv6 fragment check. See
+ // xt_action_param.fragoff in
+ // include/linux/netfilter/x_tables.h.
+ if header.IPv6(pkt.NetworkHeader().View()).TransportProtocol() != header.UDPProtocolNumber {
+ return false, false
+ }
+
+ default:
+ // We don't know the network protocol.
return false, false
}
diff --git a/pkg/sentry/socket/netlink/BUILD b/pkg/sentry/socket/netlink/BUILD
index 0546801bf..1f926aa91 100644
--- a/pkg/sentry/socket/netlink/BUILD
+++ b/pkg/sentry/socket/netlink/BUILD
@@ -16,6 +16,8 @@ go_library(
"//pkg/abi/linux",
"//pkg/binary",
"//pkg/context",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/sentry/arch",
"//pkg/sentry/device",
"//pkg/sentry/fs",
@@ -36,8 +38,6 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
],
)
diff --git a/pkg/sentry/socket/netlink/socket.go b/pkg/sentry/socket/netlink/socket.go
index 68a9b9a96..5ddcd4be5 100644
--- a/pkg/sentry/socket/netlink/socket.go
+++ b/pkg/sentry/socket/netlink/socket.go
@@ -21,6 +21,8 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -38,8 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
const sizeOfInt32 int = 4
diff --git a/pkg/sentry/socket/netstack/BUILD b/pkg/sentry/socket/netstack/BUILD
index 1fb777a6c..fae3b6783 100644
--- a/pkg/sentry/socket/netstack/BUILD
+++ b/pkg/sentry/socket/netstack/BUILD
@@ -22,6 +22,8 @@ go_library(
"//pkg/binary",
"//pkg/context",
"//pkg/log",
+ "//pkg/marshal",
+ "//pkg/marshal/primitive",
"//pkg/metric",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -51,8 +53,6 @@ go_library(
"//pkg/tcpip/transport/udp",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
- "//tools/go_marshal/primitive",
"@org_golang_x_sys//unix:go_default_library",
],
)
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index e4846bc0b..6fede181a 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -40,6 +40,8 @@ import (
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/log"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/metric"
"gvisor.dev/gvisor/pkg/safemem"
"gvisor.dev/gvisor/pkg/sentry/arch"
@@ -62,8 +64,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
func mustCreateMetric(name, description string) *tcpip.StatCounter {
@@ -158,6 +158,9 @@ var Metrics = tcpip.Stats{
OutgoingPacketErrors: mustCreateMetric("/netstack/ip/outgoing_packet_errors", "Total number of IP packets which failed to write to a link-layer endpoint."),
MalformedPacketsReceived: mustCreateMetric("/netstack/ip/malformed_packets_received", "Total number of IP packets which failed IP header validation checks."),
MalformedFragmentsReceived: mustCreateMetric("/netstack/ip/malformed_fragments_received", "Total number of IP fragments which failed IP fragment validation checks."),
+ IPTablesPreroutingDropped: mustCreateMetric("/netstack/ip/iptables/prerouting_dropped", "Total number of IP packets dropped in the Prerouting chain."),
+ IPTablesInputDropped: mustCreateMetric("/netstack/ip/iptables/input_dropped", "Total number of IP packets dropped in the Input chain."),
+ IPTablesOutputDropped: mustCreateMetric("/netstack/ip/iptables/output_dropped", "Total number of IP packets dropped in the Output chain."),
},
TCP: tcpip.TCPStats{
ActiveConnectionOpenings: mustCreateMetric("/netstack/tcp/active_connection_openings", "Number of connections opened successfully via Connect."),
@@ -236,7 +239,7 @@ type commonEndpoint interface {
// SetSockOpt implements tcpip.Endpoint.SetSockOpt and
// transport.Endpoint.SetSockOpt.
- SetSockOpt(interface{}) *tcpip.Error
+ SetSockOpt(tcpip.SettableSocketOption) *tcpip.Error
// SetSockOptBool implements tcpip.Endpoint.SetSockOptBool and
// transport.Endpoint.SetSockOptBool.
@@ -248,7 +251,7 @@ type commonEndpoint interface {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt and
// transport.Endpoint.GetSockOpt.
- GetSockOpt(interface{}) *tcpip.Error
+ GetSockOpt(tcpip.GettableSocketOption) *tcpip.Error
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool and
// transport.Endpoint.GetSockOpt.
@@ -257,6 +260,9 @@ type commonEndpoint interface {
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+
+ // LastError implements tcpip.Endpoint.LastError.
+ LastError() *tcpip.Error
}
// LINT.IfChange
@@ -479,8 +485,35 @@ func (s *socketOpsCommon) fetchReadView() *syserr.Error {
}
// Release implements fs.FileOperations.Release.
-func (s *socketOpsCommon) Release(context.Context) {
+func (s *socketOpsCommon) Release(ctx context.Context) {
+ e, ch := waiter.NewChannelEntry(nil)
+ s.EventRegister(&e, waiter.EventHUp|waiter.EventErr)
+ defer s.EventUnregister(&e)
+
s.Endpoint.Close()
+
+ // SO_LINGER option is valid only for TCP. For other socket types
+ // return after endpoint close.
+ if family, skType, _ := s.Type(); skType != linux.SOCK_STREAM || (family != linux.AF_INET && family != linux.AF_INET6) {
+ return
+ }
+
+ var v tcpip.LingerOption
+ if err := s.Endpoint.GetSockOpt(&v); err != nil {
+ return
+ }
+
+ // The case for zero timeout is handled in tcp endpoint close function.
+ // Close is blocked until either:
+ // 1. The endpoint state is not in any of the states: FIN-WAIT1,
+ // CLOSING and LAST_ACK.
+ // 2. Timeout is reached.
+ if v.Enabled && v.Timeout != 0 {
+ t := kernel.TaskFromContext(ctx)
+ start := t.Kernel().MonotonicClock().Now()
+ deadline := start.Add(v.Timeout)
+ t.BlockWithDeadline(ch, true, deadline)
+ }
}
// Read implements fs.FileOperations.Read.
@@ -803,7 +836,20 @@ func (s *socketOpsCommon) Bind(t *kernel.Task, sockaddr []byte) *syserr.Error {
}
// Issue the bind request to the endpoint.
- return syserr.TranslateNetstackError(s.Endpoint.Bind(addr))
+ err := s.Endpoint.Bind(addr)
+ if err == tcpip.ErrNoPortAvailable {
+ // Bind always returns EADDRINUSE irrespective of if the specified port was
+ // already bound or if an ephemeral port was requested but none were
+ // available.
+ //
+ // tcpip.ErrNoPortAvailable is mapped to EAGAIN in syserr package because
+ // UDP connect returns EAGAIN on ephemeral port exhaustion.
+ //
+ // TCP connect returns EADDRNOTAVAIL on ephemeral port exhaustion.
+ err = tcpip.ErrPortInUse
+ }
+
+ return syserr.TranslateNetstackError(err)
}
// Listen implements the linux syscall listen(2) for sockets backed by
@@ -814,7 +860,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
+func (s *socketOpsCommon) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (tcpip.Endpoint, *waiter.Queue, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.EventRegister(&e, waiter.EventIn)
@@ -823,7 +869,7 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite
// Try to accept the connection again; if it fails, then wait until we
// get a notification.
for {
- if ep, wq, err := s.Endpoint.Accept(); err != tcpip.ErrWouldBlock {
+ if ep, wq, err := s.Endpoint.Accept(peerAddr); err != tcpip.ErrWouldBlock {
return ep, wq, syserr.TranslateNetstackError(err)
}
@@ -836,15 +882,18 @@ func (s *socketOpsCommon) blockingAccept(t *kernel.Task) (tcpip.Endpoint, *waite
// Accept implements the linux syscall accept(2) for sockets backed by
// tcpip.Endpoint.
func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, wq, terr := s.Endpoint.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
if terr != tcpip.ErrWouldBlock || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
var err *syserr.Error
- ep, wq, err = s.blockingAccept(t)
+ ep, wq, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -864,13 +913,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer and write it to peer slice.
- var err *syserr.Error
- addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
@@ -943,47 +987,12 @@ func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr us
return &val, nil
}
- if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
- switch name {
- case linux.IPT_SO_GET_INFO:
- if outLen < linux.SizeOfIPTGetinfo {
- return nil, syserr.ErrInvalidArgument
- }
-
- stack := inet.StackFromContext(t)
- if stack == nil {
- return nil, syserr.ErrNoDevice
- }
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
- if err != nil {
- return nil, err
- }
- return &info, nil
-
- case linux.IPT_SO_GET_ENTRIES:
- if outLen < linux.SizeOfIPTGetEntries {
- return nil, syserr.ErrInvalidArgument
- }
-
- stack := inet.StackFromContext(t)
- if stack == nil {
- return nil, syserr.ErrNoDevice
- }
- entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
- if err != nil {
- return nil, err
- }
- return &entries, nil
-
- }
- }
-
- return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outPtr, outLen)
}
// GetSockOpt can be used to implement the linux syscall getsockopt(2) for
// sockets backed by a commonEndpoint.
-func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family int, skType linux.SockType, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch level {
case linux.SOL_SOCKET:
return getSockOptSocket(t, s, ep, family, skType, name, outLen)
@@ -992,10 +1001,10 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in
return getSockOptTCP(t, ep, name, outLen)
case linux.SOL_IPV6:
- return getSockOptIPv6(t, ep, name, outLen)
+ return getSockOptIPv6(t, s, ep, name, outPtr, outLen)
case linux.SOL_IP:
- return getSockOptIP(t, ep, name, outLen, family)
+ return getSockOptIP(t, s, ep, name, outPtr, outLen, family)
case linux.SOL_UDP,
linux.SOL_ICMPV6,
@@ -1025,7 +1034,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.GetSockOpt(tcpip.ErrorOption{})
+ err := ep.LastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1176,7 +1185,16 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
return nil, syserr.ErrInvalidArgument
}
- linger := linux.Linger{}
+ var v tcpip.LingerOption
+ var linger linux.Linger
+ if err := ep.GetSockOpt(&v); err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+
+ if v.Enabled {
+ linger.OnOff = 1
+ }
+ linger.Linger = int32(v.Timeout.Seconds())
return &linger, nil
case linux.SO_SNDTIMEO:
@@ -1390,8 +1408,12 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
if err := ep.GetSockOpt(&v); err != nil {
return nil, syserr.TranslateNetstackError(err)
}
-
- lingerTimeout := primitive.Int32(time.Duration(v) / time.Second)
+ var lingerTimeout primitive.Int32
+ if v >= 0 {
+ lingerTimeout = primitive.Int32(time.Duration(v) / time.Second)
+ } else {
+ lingerTimeout = -1
+ }
return &lingerTimeout, nil
case linux.TCP_DEFER_ACCEPT:
@@ -1437,7 +1459,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1490,10 +1512,50 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
- case linux.SO_ORIGINAL_DST:
+ case linux.IP6T_ORIGINAL_DST:
// TODO(gvisor.dev/issue/170): ip6tables.
return nil, syserr.ErrInvalidArgument
+ case linux.IP6T_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true)
+ if err != nil {
+ return nil, err
+ }
+ return &info, nil
+
+ case linux.IP6T_SO_GET_ENTRIES:
+ // IPTGetEntries is reused for IPv6.
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return &entries, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1501,7 +1563,7 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha
}
// getSockOptIP implements GetSockOpt when level is SOL_IP.
-func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int, family int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IP_TTL:
if outLen < sizeOfInt32 {
@@ -1617,6 +1679,46 @@ func getSockOptIP(t *kernel.Task, ep commonEndpoint, name, outLen int, family in
a, _ := ConvertAddress(linux.AF_INET, tcpip.FullAddress(v))
return a.(*linux.SockAddrInet), nil
+ case linux.IPT_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv4 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false)
+ if err != nil {
+ return nil, err
+ }
+ return &info, nil
+
+ case linux.IPT_SO_GET_ENTRIES:
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv4 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries4(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return &entries, nil
+
default:
emitUnimplementedEventIP(t, name)
}
@@ -1650,26 +1752,6 @@ func (s *SocketOperations) SetSockOpt(t *kernel.Task, level int, name int, optVa
return nil
}
- if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
- switch name {
- case linux.IPT_SO_SET_REPLACE:
- if len(optVal) < linux.SizeOfIPTReplace {
- return syserr.ErrInvalidArgument
- }
-
- stack := inet.StackFromContext(t)
- if stack == nil {
- return syserr.ErrNoDevice
- }
- // Stack must be a netstack stack.
- return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
-
- case linux.IPT_SO_SET_ADD_COUNTERS:
- // TODO(gvisor.dev/issue/170): Counter support.
- return nil
- }
- }
-
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
@@ -1684,21 +1766,26 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
return setSockOptTCP(t, ep, name, optVal)
case linux.SOL_IPV6:
- return setSockOptIPv6(t, ep, name, optVal)
+ return setSockOptIPv6(t, s, ep, name, optVal)
case linux.SOL_IP:
- return setSockOptIP(t, ep, name, optVal)
+ return setSockOptIP(t, s, ep, name, optVal)
+
+ case linux.SOL_PACKET:
+ // gVisor doesn't support any SOL_PACKET options just return not
+ // supported. Returning nil here will result in tcpdump thinking AF_PACKET
+ // features are supported and proceed to use them and break.
+ t.Kernel().EmitUnimplementedEvent(t)
+ return syserr.ErrProtocolNotAvailable
case linux.SOL_UDP,
linux.SOL_ICMPV6,
- linux.SOL_RAW,
- linux.SOL_PACKET:
+ linux.SOL_RAW:
t.Kernel().EmitUnimplementedEvent(t)
}
- // Default to the old behavior; hand off to network stack.
- return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+ return nil
}
// setSockOptSocket implements SetSockOpt when level is SOL_SOCKET.
@@ -1743,7 +1830,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
name := string(optVal[:n])
if name == "" {
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(0)))
+ v := tcpip.BindToDeviceOption(0)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
}
s := t.NetworkContext()
if s == nil {
@@ -1751,7 +1839,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
}
for nicID, nic := range s.Interfaces() {
if nic.Name == name {
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.BindToDeviceOption(nicID)))
+ v := tcpip.BindToDeviceOption(nicID)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
}
}
return syserr.ErrUnknownDevice
@@ -1817,7 +1906,8 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.OutOfBandInlineOption(v)))
+ opt := tcpip.OutOfBandInlineOption(v)
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
case linux.SO_NO_CHECK:
if len(optVal) < sizeOfInt32 {
@@ -1839,19 +1929,21 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- return nil
+ return syserr.TranslateNetstackError(
+ ep.SetSockOpt(&tcpip.LingerOption{
+ Enabled: v.OnOff != 0,
+ Timeout: time.Second * time.Duration(v.Linger)}))
case linux.SO_DETACH_FILTER:
// optval is ignored.
var v tcpip.SocketDetachFilterOption
- return syserr.TranslateNetstackError(ep.SetSockOpt(v))
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&v))
default:
socket.SetSockOptEmitUnimplementedEvent(t, name)
}
- // Default to the old behavior; hand off to network stack.
- return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+ return nil
}
// setSockOptTCP implements SetSockOpt when level is SOL_TCP.
@@ -1898,7 +1990,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
if v < 1 || v > linux.MAX_TCP_KEEPIDLE {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIdleOption(time.Second * time.Duration(v))))
+ opt := tcpip.KeepaliveIdleOption(time.Second * time.Duration(v))
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
case linux.TCP_KEEPINTVL:
if len(optVal) < sizeOfInt32 {
@@ -1909,7 +2002,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
if v < 1 || v > linux.MAX_TCP_KEEPINTVL {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))))
+ opt := tcpip.KeepaliveIntervalOption(time.Second * time.Duration(v))
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
case linux.TCP_KEEPCNT:
if len(optVal) < sizeOfInt32 {
@@ -1931,11 +2025,12 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
if v < 0 {
return syserr.ErrInvalidArgument
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))))
+ opt := tcpip.TCPUserTimeoutOption(time.Millisecond * time.Duration(v))
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
case linux.TCP_CONGESTION:
v := tcpip.CongestionControlOption(optVal)
- if err := ep.SetSockOpt(v); err != nil {
+ if err := ep.SetSockOpt(&v); err != nil {
return syserr.TranslateNetstackError(err)
}
return nil
@@ -1945,8 +2040,9 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
return syserr.ErrInvalidArgument
}
- v := usermem.ByteOrder.Uint32(optVal)
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))))
+ v := int32(usermem.ByteOrder.Uint32(optVal))
+ opt := tcpip.TCPLingerTimeoutOption(time.Second * time.Duration(v))
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
case linux.TCP_DEFER_ACCEPT:
if len(optVal) < sizeOfInt32 {
@@ -1956,7 +2052,8 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
if v < 0 {
v = 0
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))))
+ opt := tcpip.TCPDeferAcceptOption(time.Second * time.Duration(v))
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&opt))
case linux.TCP_SYNCNT:
if len(optVal) < sizeOfInt32 {
@@ -1981,12 +2078,11 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
emitUnimplementedEventTCP(t, name)
}
- // Default to the old behavior; hand off to network stack.
- return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+ return nil
}
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
-func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
@@ -2035,12 +2131,32 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ case linux.IP6T_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIP6TReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true)
+
+ case linux.IP6T_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
- // Default to the old behavior; hand off to network stack.
- return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+ return nil
}
var (
@@ -2095,7 +2211,7 @@ func parseIntOrChar(buf []byte) (int32, *syserr.Error) {
}
// setSockOptIP implements SetSockOpt when level is SOL_IP.
-func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.IP_MULTICAST_TTL:
v, err := parseIntOrChar(optVal)
@@ -2118,7 +2234,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.AddMembershipOption{
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.AddMembershipOption{
NIC: tcpip.NICID(req.InterfaceIndex),
// TODO(igudger): Change AddMembership to use the standard
// any address representation.
@@ -2132,7 +2248,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.RemoveMembershipOption{
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.RemoveMembershipOption{
NIC: tcpip.NICID(req.InterfaceIndex),
// TODO(igudger): Change DropMembership to use the standard
// any address representation.
@@ -2146,7 +2262,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
return err
}
- return syserr.TranslateNetstackError(ep.SetSockOpt(tcpip.MulticastInterfaceOption{
+ return syserr.TranslateNetstackError(ep.SetSockOpt(&tcpip.MulticastInterfaceOption{
NIC: tcpip.NICID(req.InterfaceIndex),
InterfaceAddr: bytesToIPAddress(req.InterfaceAddr[:]),
}))
@@ -2215,6 +2331,27 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
}
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.IPHdrIncludedOption, v != 0))
+ case linux.IPT_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIPTReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv4 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET || skType != linux.SOCK_RAW {
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false)
+
+ case linux.IPT_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+
case linux.IP_ADD_SOURCE_MEMBERSHIP,
linux.IP_BIND_ADDRESS_NO_PORT,
linux.IP_BLOCK_SOURCE,
@@ -2249,8 +2386,7 @@ func setSockOptIP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *s
t.Kernel().EmitUnimplementedEvent(t)
}
- // Default to the old behavior; hand off to network stack.
- return syserr.TranslateNetstackError(ep.SetSockOpt(struct{}{}))
+ return nil
}
// emitUnimplementedEventTCP emits unimplemented event if name is valid. This
diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go
index 3335e7430..c0212ad76 100644
--- a/pkg/sentry/socket/netstack/netstack_vfs2.go
+++ b/pkg/sentry/socket/netstack/netstack_vfs2.go
@@ -18,21 +18,19 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/amutex"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
- "gvisor.dev/gvisor/pkg/sentry/inet"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/socket"
- "gvisor.dev/gvisor/pkg/sentry/socket/netfilter"
"gvisor.dev/gvisor/pkg/sentry/vfs"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/syserror"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
- "gvisor.dev/gvisor/tools/go_marshal/primitive"
)
// SocketVFS2 encapsulates all the state needed to represent a network stack
@@ -58,6 +56,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu
mnt := t.Kernel().SocketMount()
d := sockfs.NewDentry(t.Credentials(), mnt)
+ defer d.DecRef(t)
s := &SocketVFS2{
socketOpsCommon: socketOpsCommon{
@@ -152,14 +151,18 @@ func (s *SocketVFS2) Write(ctx context.Context, src usermem.IOSequence, opts vfs
// tcpip.Endpoint.
func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
// Issue the accept request to get the new endpoint.
- ep, wq, terr := s.Endpoint.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, wq, terr := s.Endpoint.Accept(peerAddr)
if terr != nil {
if terr != tcpip.ErrWouldBlock || !blocking {
return 0, nil, 0, syserr.TranslateNetstackError(terr)
}
var err *syserr.Error
- ep, wq, err = s.blockingAccept(t)
+ ep, wq, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -177,13 +180,9 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
+ if peerAddr != nil {
// Get address of the peer and write it to peer slice.
- var err *syserr.Error
- addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ addr, addrLen = ConvertAddress(s.family, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{
@@ -233,42 +232,7 @@ func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.
return &val, nil
}
- if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
- switch name {
- case linux.IPT_SO_GET_INFO:
- if outLen < linux.SizeOfIPTGetinfo {
- return nil, syserr.ErrInvalidArgument
- }
-
- stack := inet.StackFromContext(t)
- if stack == nil {
- return nil, syserr.ErrNoDevice
- }
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
- if err != nil {
- return nil, err
- }
- return &info, nil
-
- case linux.IPT_SO_GET_ENTRIES:
- if outLen < linux.SizeOfIPTGetEntries {
- return nil, syserr.ErrInvalidArgument
- }
-
- stack := inet.StackFromContext(t)
- if stack == nil {
- return nil, syserr.ErrNoDevice
- }
- entries, err := netfilter.GetEntries(t, stack.(*Stack).Stack, outPtr, outLen)
- if err != nil {
- return nil, err
- }
- return &entries, nil
-
- }
- }
-
- return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outLen)
+ return GetSockOpt(t, s, s.Endpoint, s.family, s.skType, level, name, outPtr, outLen)
}
// SetSockOpt implements the linux syscall setsockopt(2) for sockets backed by
@@ -298,26 +262,6 @@ func (s *SocketVFS2) SetSockOpt(t *kernel.Task, level int, name int, optVal []by
return nil
}
- if s.skType == linux.SOCK_RAW && level == linux.IPPROTO_IP {
- switch name {
- case linux.IPT_SO_SET_REPLACE:
- if len(optVal) < linux.SizeOfIPTReplace {
- return syserr.ErrInvalidArgument
- }
-
- stack := inet.StackFromContext(t)
- if stack == nil {
- return syserr.ErrNoDevice
- }
- // Stack must be a netstack stack.
- return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
-
- case linux.IPT_SO_SET_ADD_COUNTERS:
- // TODO(gvisor.dev/issue/170): Counter support.
- return nil
- }
- }
-
return SetSockOpt(t, s, s.Endpoint, level, name, optVal)
}
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index f9097d6b2..1028d2a6e 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -155,7 +155,7 @@ func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error {
// TCPReceiveBufferSize implements inet.Stack.TCPReceiveBufferSize.
func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
- var rs tcp.ReceiveBufferSizeOption
+ var rs tcpip.TCPReceiveBufferSizeRangeOption
err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &rs)
return inet.TCPBufferSize{
Min: rs.Min,
@@ -166,17 +166,17 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
- rs := tcp.ReceiveBufferSizeOption{
+ rs := tcpip.TCPReceiveBufferSizeRangeOption{
Min: size.Min,
Default: size.Default,
Max: size.Max,
}
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, rs)).ToError()
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &rs)).ToError()
}
// TCPSendBufferSize implements inet.Stack.TCPSendBufferSize.
func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
- var ss tcp.SendBufferSizeOption
+ var ss tcpip.TCPSendBufferSizeRangeOption
err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &ss)
return inet.TCPBufferSize{
Min: ss.Min,
@@ -187,29 +187,30 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
- ss := tcp.SendBufferSizeOption{
+ ss := tcpip.TCPSendBufferSizeRangeOption{
Min: size.Min,
Default: size.Default,
Max: size.Max,
}
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, ss)).ToError()
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &ss)).ToError()
}
// TCPSACKEnabled implements inet.Stack.TCPSACKEnabled.
func (s *Stack) TCPSACKEnabled() (bool, error) {
- var sack tcp.SACKEnabled
+ var sack tcpip.TCPSACKEnabled
err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &sack)
return bool(sack), syserr.TranslateNetstackError(err).ToError()
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
func (s *Stack) SetTCPSACKEnabled(enabled bool) error {
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.SACKEnabled(enabled))).ToError()
+ opt := tcpip.TCPSACKEnabled(enabled)
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError()
}
// TCPRecovery implements inet.Stack.TCPRecovery.
func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
- var recovery tcp.Recovery
+ var recovery tcpip.TCPRecovery
if err := s.Stack.TransportProtocolOption(tcp.ProtocolNumber, &recovery); err != nil {
return 0, syserr.TranslateNetstackError(err).ToError()
}
@@ -218,7 +219,8 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error {
- return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, tcp.Recovery(recovery))).ToError()
+ opt := tcpip.TCPRecovery(recovery)
+ return syserr.TranslateNetstackError(s.Stack.SetTransportProtocolOption(tcp.ProtocolNumber, &opt)).ToError()
}
// Statistics implements inet.Stack.Statistics.
@@ -417,8 +419,7 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool {
case ipv4.ProtocolNumber, ipv6.ProtocolNumber:
return s.Stack.Forwarding(protocol)
default:
- log.Warningf("Forwarding(%v) failed: unsupported protocol", protocol)
- return false
+ panic(fmt.Sprintf("Forwarding(%v) failed: unsupported protocol", protocol))
}
}
@@ -428,8 +429,7 @@ func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool)
case ipv4.ProtocolNumber, ipv6.ProtocolNumber:
s.Stack.SetForwarding(protocol, enable)
default:
- log.Warningf("SetForwarding(%v) failed: unsupported protocol", protocol)
- return syserr.ErrProtocolNotSupported.ToError()
+ panic(fmt.Sprintf("SetForwarding(%v) failed: unsupported protocol", protocol))
}
return nil
}
diff --git a/pkg/sentry/socket/socket.go b/pkg/sentry/socket/socket.go
index 04b259d27..fd31479e5 100644
--- a/pkg/sentry/socket/socket.go
+++ b/pkg/sentry/socket/socket.go
@@ -25,6 +25,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/binary"
"gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/device"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -35,7 +36,6 @@ import (
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// ControlMessages represents the union of unix control messages and tcpip
diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD
index cb953e4dc..a89583dad 100644
--- a/pkg/sentry/socket/unix/BUILD
+++ b/pkg/sentry/socket/unix/BUILD
@@ -29,6 +29,7 @@ go_library(
"//pkg/context",
"//pkg/fspath",
"//pkg/log",
+ "//pkg/marshal",
"//pkg/refs",
"//pkg/safemem",
"//pkg/sentry/arch",
@@ -49,6 +50,5 @@ go_library(
"//pkg/tcpip",
"//pkg/usermem",
"//pkg/waiter",
- "//tools/go_marshal/marshal",
],
)
diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD
index c708b6030..26c3a51b9 100644
--- a/pkg/sentry/socket/unix/transport/BUILD
+++ b/pkg/sentry/socket/unix/transport/BUILD
@@ -15,6 +15,17 @@ go_template_instance(
},
)
+go_template_instance(
+ name = "queue_refs",
+ out = "queue_refs.go",
+ package = "transport",
+ prefix = "queue",
+ template = "//pkg/refs_vfs2:refs_template",
+ types = {
+ "T": "queue",
+ },
+)
+
go_library(
name = "transport",
srcs = [
@@ -22,6 +33,7 @@ go_library(
"connectioned_state.go",
"connectionless.go",
"queue.go",
+ "queue_refs.go",
"transport_message_list.go",
"unix.go",
],
diff --git a/pkg/sentry/socket/unix/transport/connectioned.go b/pkg/sentry/socket/unix/transport/connectioned.go
index c67b602f0..aa4f3c04d 100644
--- a/pkg/sentry/socket/unix/transport/connectioned.go
+++ b/pkg/sentry/socket/unix/transport/connectioned.go
@@ -142,9 +142,9 @@ func NewPair(ctx context.Context, stype linux.SockType, uid UniqueIDProvider) (E
}
q1 := &queue{ReaderQueue: a.Queue, WriterQueue: b.Queue, limit: initialLimit}
- q1.EnableLeakCheck("transport.queue")
+ q1.EnableLeakCheck()
q2 := &queue{ReaderQueue: b.Queue, WriterQueue: a.Queue, limit: initialLimit}
- q2.EnableLeakCheck("transport.queue")
+ q2.EnableLeakCheck()
if stype == linux.SOCK_STREAM {
a.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{q1}}
@@ -300,14 +300,14 @@ func (e *connectionedEndpoint) BidirectionalConnect(ctx context.Context, ce Conn
}
readQueue := &queue{ReaderQueue: ce.WaiterQueue(), WriterQueue: ne.Queue, limit: initialLimit}
- readQueue.EnableLeakCheck("transport.queue")
+ readQueue.EnableLeakCheck()
ne.connected = &connectedEndpoint{
endpoint: ce,
writeQueue: readQueue,
}
writeQueue := &queue{ReaderQueue: ne.Queue, WriterQueue: ce.WaiterQueue(), limit: initialLimit}
- writeQueue.EnableLeakCheck("transport.queue")
+ writeQueue.EnableLeakCheck()
if e.stype == linux.SOCK_STREAM {
ne.receiver = &streamQueueReceiver{queueReceiver: queueReceiver{readQueue: writeQueue}}
} else {
@@ -391,7 +391,7 @@ func (e *connectionedEndpoint) Listen(backlog int) *syserr.Error {
}
// Accept accepts a new connection.
-func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
+func (e *connectionedEndpoint) Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error) {
e.Lock()
defer e.Unlock()
@@ -401,6 +401,18 @@ func (e *connectionedEndpoint) Accept() (Endpoint, *syserr.Error) {
select {
case ne := <-e.acceptedChan:
+ if peerAddr != nil {
+ ne.Lock()
+ c := ne.connected
+ ne.Unlock()
+ if c != nil {
+ addr, err := c.GetLocalAddress()
+ if err != nil {
+ return nil, syserr.TranslateNetstackError(err)
+ }
+ *peerAddr = addr
+ }
+ }
return ne, nil
default:
diff --git a/pkg/sentry/socket/unix/transport/connectionless.go b/pkg/sentry/socket/unix/transport/connectionless.go
index 70ee8f9b8..f8aacca13 100644
--- a/pkg/sentry/socket/unix/transport/connectionless.go
+++ b/pkg/sentry/socket/unix/transport/connectionless.go
@@ -42,7 +42,7 @@ var (
func NewConnectionless(ctx context.Context) Endpoint {
ep := &connectionlessEndpoint{baseEndpoint{Queue: &waiter.Queue{}}}
q := queue{ReaderQueue: ep.Queue, WriterQueue: &waiter.Queue{}, limit: initialLimit}
- q.EnableLeakCheck("transport.queue")
+ q.EnableLeakCheck()
ep.receiver = &queueReceiver{readQueue: &q}
return ep
}
@@ -144,12 +144,12 @@ func (e *connectionlessEndpoint) Connect(ctx context.Context, server BoundEndpoi
}
// Listen starts listening on the connection.
-func (e *connectionlessEndpoint) Listen(int) *syserr.Error {
+func (*connectionlessEndpoint) Listen(int) *syserr.Error {
return syserr.ErrNotSupported
}
// Accept accepts a new connection.
-func (e *connectionlessEndpoint) Accept() (Endpoint, *syserr.Error) {
+func (*connectionlessEndpoint) Accept(*tcpip.FullAddress) (Endpoint, *syserr.Error) {
return nil, syserr.ErrNotSupported
}
diff --git a/pkg/sentry/socket/unix/transport/queue.go b/pkg/sentry/socket/unix/transport/queue.go
index ef6043e19..342def28f 100644
--- a/pkg/sentry/socket/unix/transport/queue.go
+++ b/pkg/sentry/socket/unix/transport/queue.go
@@ -16,7 +16,6 @@ package transport
import (
"gvisor.dev/gvisor/pkg/context"
- "gvisor.dev/gvisor/pkg/refs"
"gvisor.dev/gvisor/pkg/sync"
"gvisor.dev/gvisor/pkg/syserr"
"gvisor.dev/gvisor/pkg/tcpip"
@@ -28,7 +27,7 @@ import (
//
// +stateify savable
type queue struct {
- refs.AtomicRefCount
+ queueRefs
ReaderQueue *waiter.Queue
WriterQueue *waiter.Queue
@@ -68,11 +67,13 @@ func (q *queue) Reset(ctx context.Context) {
q.mu.Unlock()
}
-// DecRef implements RefCounter.DecRef with destructor q.Reset.
+// DecRef implements RefCounter.DecRef.
func (q *queue) DecRef(ctx context.Context) {
- q.DecRefWithDestructor(ctx, q.Reset)
- // We don't need to notify after resetting because no one cares about
- // this queue after all references have been dropped.
+ q.queueRefs.DecRef(func() {
+ // We don't need to notify after resetting because no one cares about
+ // this queue after all references have been dropped.
+ q.Reset(ctx)
+ })
}
// IsReadable determines if q is currently readable.
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index 475d7177e..d6fc03520 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -151,7 +151,10 @@ type Endpoint interface {
// block if no new connections are available.
//
// The returned Queue is the wait queue for the newly created endpoint.
- Accept() (Endpoint, *syserr.Error)
+ //
+ // peerAddr if not nil will be populated with the address of the connected
+ // peer on a successful accept.
+ Accept(peerAddr *tcpip.FullAddress) (Endpoint, *syserr.Error)
// Bind binds the endpoint to a specific local address and port.
// Specifying a NIC is optional.
@@ -172,9 +175,8 @@ type Endpoint interface {
// connected.
GetRemoteAddress() (tcpip.FullAddress, *tcpip.Error)
- // SetSockOpt sets a socket option. opt should be one of the tcpip.*Option
- // types.
- SetSockOpt(opt interface{}) *tcpip.Error
+ // SetSockOpt sets a socket option.
+ SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error
// SetSockOptBool sets a socket option for simple cases when a value has
// the int type.
@@ -184,9 +186,8 @@ type Endpoint interface {
// the int type.
SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error
- // GetSockOpt gets a socket option. opt should be a pointer to one of the
- // tcpip.*Option types.
- GetSockOpt(opt interface{}) *tcpip.Error
+ // GetSockOpt gets a socket option.
+ GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error
// GetSockOptBool gets a socket option for simple cases when a return
// value has the int type.
@@ -199,6 +200,9 @@ type Endpoint interface {
// State returns the current state of the socket, as represented by Linux in
// procfs.
State() uint32
+
+ // LastError implements tcpip.Endpoint.LastError.
+ LastError() *tcpip.Error
}
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
@@ -742,6 +746,9 @@ type baseEndpoint struct {
// path is not empty if the endpoint has been bound,
// or may be used if the endpoint is connected.
path string
+
+ // linger is used for SO_LINGER socket option.
+ linger tcpip.LingerOption
}
// EventRegister implements waiter.Waitable.EventRegister.
@@ -837,8 +844,14 @@ func (e *baseEndpoint) SendMsg(ctx context.Context, data [][]byte, c ControlMess
return n, err
}
-// SetSockOpt sets a socket option. Currently not supported.
-func (e *baseEndpoint) SetSockOpt(opt interface{}) *tcpip.Error {
+// SetSockOpt sets a socket option.
+func (e *baseEndpoint) SetSockOpt(opt tcpip.SettableSocketOption) *tcpip.Error {
+ switch v := opt.(type) {
+ case *tcpip.LingerOption:
+ e.Lock()
+ e.linger = *v
+ e.Unlock()
+ }
return nil
}
@@ -940,9 +953,12 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
+func (e *baseEndpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error {
+ switch o := opt.(type) {
+ case *tcpip.LingerOption:
+ e.Lock()
+ *o = e.linger
+ e.Unlock()
return nil
default:
@@ -951,6 +967,11 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
+// LastError implements Endpoint.LastError.
+func (*baseEndpoint) LastError() *tcpip.Error {
+ return nil
+}
+
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go
index b7e8e4325..917055cea 100644
--- a/pkg/sentry/socket/unix/unix.go
+++ b/pkg/sentry/socket/unix/unix.go
@@ -24,6 +24,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/fs/fsutil"
@@ -39,7 +40,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketOperations is a Unix socket. It is similar to a netstack socket,
@@ -194,7 +194,7 @@ func (s *SocketOperations) Ioctl(ctx context.Context, _ *fs.File, io usermem.IO,
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
- return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen)
}
// Listen implements the linux syscall listen(2) for sockets backed by
@@ -205,7 +205,7 @@ func (s *socketOpsCommon) Listen(t *kernel.Task, backlog int) *syserr.Error {
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+func (s *SocketOperations) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.EventRegister(&e, waiter.EventIn)
@@ -214,7 +214,7 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Try to accept the connection; if it fails, then wait until we get a
// notification.
for {
- if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock {
return ep, err
}
@@ -227,15 +227,18 @@ func (s *SocketOperations) blockingAccept(t *kernel.Task) (transport.Endpoint, *
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, err := s.ep.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, err := s.ep.Accept(peerAddr)
if err != nil {
if err != syserr.ErrWouldBlock || !blocking {
return 0, nil, 0, err
}
var err *syserr.Error
- ep, err = s.blockingAccept(t)
+ ep, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -252,13 +255,8 @@ func (s *SocketOperations) Accept(t *kernel.Task, peerRequested bool, flags int,
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer.
- var err *syserr.Error
- addr, addrLen, err = ns.FileOperations.(*SocketOperations).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFrom(0, ns, kernel.FDFlags{
diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go
index d066ef8ab..3688f22d2 100644
--- a/pkg/sentry/socket/unix/unix_vfs2.go
+++ b/pkg/sentry/socket/unix/unix_vfs2.go
@@ -18,6 +18,7 @@ import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/fspath"
+ "gvisor.dev/gvisor/pkg/marshal"
"gvisor.dev/gvisor/pkg/sentry/arch"
fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock"
"gvisor.dev/gvisor/pkg/sentry/fsimpl/sockfs"
@@ -32,7 +33,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/usermem"
"gvisor.dev/gvisor/pkg/waiter"
- "gvisor.dev/gvisor/tools/go_marshal/marshal"
)
// SocketVFS2 implements socket.SocketVFS2 (and by extension,
@@ -91,12 +91,12 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3
// GetSockOpt implements the linux syscall getsockopt(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketVFS2) GetSockOpt(t *kernel.Task, level, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
- return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outLen)
+ return netstack.GetSockOpt(t, s, s.ep, linux.AF_UNIX, s.ep.Type(), level, name, outPtr, outLen)
}
// blockingAccept implements a blocking version of accept(2), that is, if no
// connections are ready to be accept, it will block until one becomes ready.
-func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr.Error) {
+func (s *SocketVFS2) blockingAccept(t *kernel.Task, peerAddr *tcpip.FullAddress) (transport.Endpoint, *syserr.Error) {
// Register for notifications.
e, ch := waiter.NewChannelEntry(nil)
s.socketOpsCommon.EventRegister(&e, waiter.EventIn)
@@ -105,7 +105,7 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr
// Try to accept the connection; if it fails, then wait until we get a
// notification.
for {
- if ep, err := s.ep.Accept(); err != syserr.ErrWouldBlock {
+ if ep, err := s.ep.Accept(peerAddr); err != syserr.ErrWouldBlock {
return ep, err
}
@@ -118,15 +118,18 @@ func (s *SocketVFS2) blockingAccept(t *kernel.Task) (transport.Endpoint, *syserr
// Accept implements the linux syscall accept(2) for sockets backed by
// a transport.Endpoint.
func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, blocking bool) (int32, linux.SockAddr, uint32, *syserr.Error) {
- // Issue the accept request to get the new endpoint.
- ep, err := s.ep.Accept()
+ var peerAddr *tcpip.FullAddress
+ if peerRequested {
+ peerAddr = &tcpip.FullAddress{}
+ }
+ ep, err := s.ep.Accept(peerAddr)
if err != nil {
if err != syserr.ErrWouldBlock || !blocking {
return 0, nil, 0, err
}
var err *syserr.Error
- ep, err = s.blockingAccept(t)
+ ep, err = s.blockingAccept(t, peerAddr)
if err != nil {
return 0, nil, 0, err
}
@@ -144,13 +147,8 @@ func (s *SocketVFS2) Accept(t *kernel.Task, peerRequested bool, flags int, block
var addr linux.SockAddr
var addrLen uint32
- if peerRequested {
- // Get address of the peer.
- var err *syserr.Error
- addr, addrLen, err = ns.Impl().(*SocketVFS2).GetPeerName(t)
- if err != nil {
- return 0, nil, 0, err
- }
+ if peerAddr != nil {
+ addr, addrLen = netstack.ConvertAddress(linux.AF_UNIX, *peerAddr)
}
fd, e := t.NewFDFromVFS2(0, ns, kernel.FDFlags{