summaryrefslogtreecommitdiffhomepage
path: root/pkg
diff options
context:
space:
mode:
Diffstat (limited to 'pkg')
-rw-r--r--pkg/abi/linux/netfilter_ipv6.go13
-rw-r--r--pkg/refs/refcounter.go33
-rw-r--r--pkg/sentry/platform/kvm/virtual_map.go2
-rw-r--r--pkg/sentry/socket/netfilter/BUILD1
-rw-r--r--pkg/sentry/socket/netfilter/ipv4.go33
-rw-r--r--pkg/sentry/socket/netfilter/ipv6.go265
-rw-r--r--pkg/sentry/socket/netfilter/netfilter.go77
-rw-r--r--pkg/sentry/socket/netfilter/targets.go10
-rw-r--r--pkg/sentry/socket/netstack/netstack.go80
-rw-r--r--pkg/sentry/socket/unix/transport/unix.go10
-rw-r--r--pkg/sentry/watchdog/watchdog.go28
-rw-r--r--pkg/tcpip/adapters/gonet/gonet.go2
-rw-r--r--pkg/tcpip/adapters/gonet/gonet_test.go2
-rw-r--r--pkg/tcpip/sample/tun_tcp_connect/main.go2
-rw-r--r--pkg/tcpip/stack/iptables.go12
-rw-r--r--pkg/tcpip/stack/iptables_types.go5
-rw-r--r--pkg/tcpip/stack/transport_test.go26
-rw-r--r--pkg/tcpip/tcpip.go7
-rw-r--r--pkg/tcpip/transport/icmp/endpoint.go15
-rw-r--r--pkg/tcpip/transport/packet/endpoint.go8
-rw-r--r--pkg/tcpip/transport/raw/endpoint.go14
-rw-r--r--pkg/tcpip/transport/tcp/connect.go4
-rw-r--r--pkg/tcpip/transport/tcp/dual_stack_test.go6
-rw-r--r--pkg/tcpip/transport/tcp/endpoint.go5
-rw-r--r--pkg/tcpip/transport/tcp/tcp_test.go10
-rw-r--r--pkg/tcpip/transport/tcp/testing/context/context.go5
-rw-r--r--pkg/tcpip/transport/udp/endpoint.go8
-rw-r--r--pkg/test/testutil/testutil.go31
28 files changed, 575 insertions, 139 deletions
diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go
index 9bb9efb10..f6117024c 100644
--- a/pkg/abi/linux/netfilter_ipv6.go
+++ b/pkg/abi/linux/netfilter_ipv6.go
@@ -290,6 +290,19 @@ type IP6TIP struct {
const SizeOfIP6TIP = 136
+// Flags in IP6TIP.Flags. Corresponding constants are in
+// include/uapi/linux/netfilter_ipv6/ip6_tables.h.
+const (
+ // Whether to check the Protocol field.
+ IP6T_F_PROTO = 0x01
+ // Whether to match the TOS field.
+ IP6T_F_TOS = 0x02
+ // Indicates that the jump target is an aboslute GOTO, not an offset.
+ IP6T_F_GOTO = 0x04
+ // Enables all flags.
+ IP6T_F_MASK = 0x07
+)
+
// Flags in IP6TIP.InverseFlags. Corresponding constants are in
// include/uapi/linux/netfilter_ipv6/ip6_tables.h.
const (
diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go
index d9d5e6bcb..57d8542b9 100644
--- a/pkg/refs/refcounter.go
+++ b/pkg/refs/refcounter.go
@@ -234,6 +234,39 @@ const (
LeaksLogTraces
)
+// Set implements flag.Value.
+func (l *LeakMode) Set(v string) error {
+ switch v {
+ case "disabled":
+ *l = NoLeakChecking
+ case "log-names":
+ *l = LeaksLogWarning
+ case "log-traces":
+ *l = LeaksLogTraces
+ default:
+ return fmt.Errorf("invalid ref leak mode %q", v)
+ }
+ return nil
+}
+
+// Get implements flag.Value.
+func (l *LeakMode) Get() interface{} {
+ return *l
+}
+
+// String implements flag.Value.
+func (l *LeakMode) String() string {
+ switch *l {
+ case NoLeakChecking:
+ return "disabled"
+ case LeaksLogWarning:
+ return "log-names"
+ case LeaksLogTraces:
+ return "log-traces"
+ }
+ panic(fmt.Sprintf("invalid ref leak mode %q", *l))
+}
+
// leakMode stores the current mode for the reference leak checker.
//
// Values must be one of the LeakMode values.
diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go
index c8897d34f..4dcdbf8a7 100644
--- a/pkg/sentry/platform/kvm/virtual_map.go
+++ b/pkg/sentry/platform/kvm/virtual_map.go
@@ -34,7 +34,7 @@ type virtualRegion struct {
}
// mapsLine matches a single line from /proc/PID/maps.
-var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\\s+(.*)")
+var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2,3}:[0-9a-f]{2,} [0-9]+\\s+(.*)")
// excludeRegion returns true if these regions should be excluded from the
// physical map. Virtual regions need to be excluded if get_user_pages will
diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD
index 795620589..8aea0200f 100644
--- a/pkg/sentry/socket/netfilter/BUILD
+++ b/pkg/sentry/socket/netfilter/BUILD
@@ -7,6 +7,7 @@ go_library(
srcs = [
"extensions.go",
"ipv4.go",
+ "ipv6.go",
"netfilter.go",
"owner_matcher.go",
"targets.go",
diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go
index 4fb887e49..e4c55a100 100644
--- a/pkg/sentry/socket/netfilter/ipv4.go
+++ b/pkg/sentry/socket/netfilter/ipv4.go
@@ -36,14 +36,37 @@ var emptyIPv4Filter = stack.IPHeaderFilter{
SrcMask: "\x00\x00\x00\x00",
}
-func getEntries4(table stack.Table, info *linux.IPTGetinfo) linux.KernelIPTGetEntries {
+// convertNetstackToBinary4 converts the iptables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a little data, reading some
+// offsets, jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+ }
+
+ table, ok := stk.IPTables().GetTable(tablename.String(), false)
+ if !ok {
+ return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
+ // Setup the info struct.
+ entries, info := getEntries4(table, tablename)
+ return entries, info, nil
+}
+
+func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo) {
+ var info linux.IPTGetinfo
var entries linux.KernelIPTGetEntries
+ copy(info.Name[:], tablename[:])
copy(entries.Name[:], info.Name[:])
+ info.ValidHooks = table.ValidHooks()
for ruleIdx, rule := range table.Rules {
nflog("convert to binary: current offset: %d", entries.Size)
- setHooksAndUnderflow(info, table, entries.Size, ruleIdx)
+ setHooksAndUnderflow(&info, table, entries.Size, ruleIdx)
// Each rule corresponds to an entry.
entry := linux.KernelIPTEntry{
Entry: linux.IPTEntry{
@@ -100,7 +123,7 @@ func getEntries4(table stack.Table, info *linux.IPTGetinfo) linux.KernelIPTGetEn
info.Size = entries.Size
nflog("convert to binary: finished with an marshalled size of %d", info.Size)
- return entries
+ return entries, info
}
func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
@@ -205,7 +228,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) {
ifnameMask := string(iptip.OutputInterfaceMask[:n])
return stack.IPHeaderFilter{
- Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ // A Protocol value of 0 indicates all protocols match.
+ CheckProtocol: iptip.Protocol != 0,
Dst: tcpip.Address(iptip.Dst[:]),
DstMask: tcpip.Address(iptip.DstMask[:]),
DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0,
diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go
new file mode 100644
index 000000000..3b2c1becd
--- /dev/null
+++ b/pkg/sentry/socket/netfilter/ipv6.go
@@ -0,0 +1,265 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package netfilter
+
+import (
+ "bytes"
+ "fmt"
+
+ "gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/binary"
+ "gvisor.dev/gvisor/pkg/syserr"
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/header"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// emptyIPv6Filter is for comparison with a rule's filters to determine whether
+// it is also empty. It is immutable.
+var emptyIPv6Filter = stack.IPHeaderFilter{
+ Dst: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ DstMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ Src: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+ SrcMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00",
+}
+
+// convertNetstackToBinary6 converts the ip6tables as stored in netstack to the
+// format expected by the iptables tool. Linux stores each table as a binary
+// blob that can only be traversed by parsing a little data, reading some
+// offsets, jumping to those offsets, parsing again, etc.
+func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo, error) {
+ // The table name has to fit in the struct.
+ if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
+ return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+ }
+
+ table, ok := stk.IPTables().GetTable(tablename.String(), true)
+ if !ok {
+ return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ }
+
+ // Setup the info struct, which is the same in IPv4 and IPv6.
+ entries, info := getEntries6(table, tablename)
+ return entries, info, nil
+}
+
+func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo) {
+ var info linux.IPTGetinfo
+ var entries linux.KernelIP6TGetEntries
+ copy(info.Name[:], tablename[:])
+ copy(entries.Name[:], info.Name[:])
+ info.ValidHooks = table.ValidHooks()
+
+ for ruleIdx, rule := range table.Rules {
+ nflog("convert to binary: current offset: %d", entries.Size)
+
+ setHooksAndUnderflow(&info, table, entries.Size, ruleIdx)
+ // Each rule corresponds to an entry.
+ entry := linux.KernelIP6TEntry{
+ Entry: linux.IP6TEntry{
+ IPv6: linux.IP6TIP{
+ Protocol: uint16(rule.Filter.Protocol),
+ },
+ NextOffset: linux.SizeOfIP6TEntry,
+ TargetOffset: linux.SizeOfIP6TEntry,
+ },
+ }
+ copy(entry.Entry.IPv6.Dst[:], rule.Filter.Dst)
+ copy(entry.Entry.IPv6.DstMask[:], rule.Filter.DstMask)
+ copy(entry.Entry.IPv6.Src[:], rule.Filter.Src)
+ copy(entry.Entry.IPv6.SrcMask[:], rule.Filter.SrcMask)
+ copy(entry.Entry.IPv6.OutputInterface[:], rule.Filter.OutputInterface)
+ copy(entry.Entry.IPv6.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask)
+ if rule.Filter.DstInvert {
+ entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_DSTIP
+ }
+ if rule.Filter.SrcInvert {
+ entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_SRCIP
+ }
+ if rule.Filter.OutputInterfaceInvert {
+ entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_VIA_OUT
+ }
+ if rule.Filter.CheckProtocol {
+ entry.Entry.IPv6.Flags |= linux.IP6T_F_PROTO
+ }
+
+ for _, matcher := range rule.Matchers {
+ // Serialize the matcher and add it to the
+ // entry.
+ serialized := marshalMatcher(matcher)
+ nflog("convert to binary: matcher serialized as: %v", serialized)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+ entry.Entry.TargetOffset += uint16(len(serialized))
+ }
+
+ // Serialize and append the target.
+ serialized := marshalTarget(rule.Target)
+ if len(serialized)%8 != 0 {
+ panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target))
+ }
+ entry.Elems = append(entry.Elems, serialized...)
+ entry.Entry.NextOffset += uint16(len(serialized))
+
+ nflog("convert to binary: adding entry: %+v", entry)
+
+ entries.Size += uint32(entry.Entry.NextOffset)
+ entries.Entrytable = append(entries.Entrytable, entry)
+ info.NumEntries++
+ }
+
+ info.Size = entries.Size
+ nflog("convert to binary: finished with an marshalled size of %d", info.Size)
+ return entries, info
+}
+
+func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) {
+ nflog("set entries: setting entries in table %q", replace.Name.String())
+
+ // Convert input into a list of rules and their offsets.
+ var offset uint32
+ // offsets maps rule byte offsets to their position in table.Rules.
+ offsets := map[uint32]int{}
+ for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ {
+ nflog("set entries: processing entry at offset %d", offset)
+
+ // Get the struct ipt_entry.
+ if len(optVal) < linux.SizeOfIP6TEntry {
+ nflog("optVal has insufficient size for entry %d", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ var entry linux.IP6TEntry
+ buf := optVal[:linux.SizeOfIP6TEntry]
+ binary.Unmarshal(buf, usermem.ByteOrder, &entry)
+ initialOptValLen := len(optVal)
+ optVal = optVal[linux.SizeOfIP6TEntry:]
+
+ if entry.TargetOffset < linux.SizeOfIP6TEntry {
+ nflog("entry has too-small target offset %d", entry.TargetOffset)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): We should support more IPTIP
+ // filtering fields.
+ filter, err := filterFromIP6TIP(entry.IPv6)
+ if err != nil {
+ nflog("bad iptip: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // TODO(gvisor.dev/issue/170): Matchers and targets can specify
+ // that they only work for certain protocols, hooks, tables.
+ // Get matchers.
+ matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry
+ if len(optVal) < int(matchersSize) {
+ nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ matchers, err := parseMatchers(filter, optVal[:matchersSize])
+ if err != nil {
+ nflog("failed to parse matchers: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+ optVal = optVal[matchersSize:]
+
+ // Get the target of the rule.
+ targetSize := entry.NextOffset - entry.TargetOffset
+ if len(optVal) < int(targetSize) {
+ nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ target, err := parseTarget(filter, optVal[:targetSize])
+ if err != nil {
+ nflog("failed to parse target: %v", err)
+ return nil, syserr.ErrInvalidArgument
+ }
+ optVal = optVal[targetSize:]
+
+ table.Rules = append(table.Rules, stack.Rule{
+ Filter: filter,
+ Target: target,
+ Matchers: matchers,
+ })
+ offsets[offset] = int(entryIdx)
+ offset += uint32(entry.NextOffset)
+
+ if initialOptValLen-len(optVal) != int(entry.NextOffset) {
+ nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal))
+ return nil, syserr.ErrInvalidArgument
+ }
+ }
+ return offsets, nil
+}
+
+func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) {
+ if containsUnsupportedFields6(iptip) {
+ return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip)
+ }
+ if len(iptip.Dst) != header.IPv6AddressSize || len(iptip.DstMask) != header.IPv6AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask))
+ }
+ if len(iptip.Src) != header.IPv6AddressSize || len(iptip.SrcMask) != header.IPv6AddressSize {
+ return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask))
+ }
+
+ n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterface)
+ }
+ ifname := string(iptip.OutputInterface[:n])
+
+ n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0)
+ if n == -1 {
+ n = len(iptip.OutputInterfaceMask)
+ }
+ ifnameMask := string(iptip.OutputInterfaceMask[:n])
+
+ return stack.IPHeaderFilter{
+ Protocol: tcpip.TransportProtocolNumber(iptip.Protocol),
+ // In ip6tables a flag controls whether to check the protocol.
+ CheckProtocol: iptip.Flags&linux.IP6T_F_PROTO != 0,
+ Dst: tcpip.Address(iptip.Dst[:]),
+ DstMask: tcpip.Address(iptip.DstMask[:]),
+ DstInvert: iptip.InverseFlags&linux.IP6T_INV_DSTIP != 0,
+ Src: tcpip.Address(iptip.Src[:]),
+ SrcMask: tcpip.Address(iptip.SrcMask[:]),
+ SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0,
+ OutputInterface: ifname,
+ OutputInterfaceMask: ifnameMask,
+ OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0,
+ }, nil
+}
+
+func containsUnsupportedFields6(iptip linux.IP6TIP) bool {
+ // The following features are supported:
+ // - Protocol
+ // - Dst and DstMask
+ // - Src and SrcMask
+ // - The inverse destination IP check flag
+ // - OutputInterface, OutputInterfaceMask and its inverse.
+ var emptyInterface = [linux.IFNAMSIZ]byte{}
+ flagMask := uint8(linux.IP6T_F_PROTO)
+ // Disable any supported inverse flags.
+ inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT)
+ return iptip.InputInterface != emptyInterface ||
+ iptip.InputInterfaceMask != emptyInterface ||
+ iptip.Flags&^flagMask != 0 ||
+ iptip.InverseFlags&^inverseMask != 0 ||
+ iptip.TOS != 0
+}
diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go
index df256676f..3e1735079 100644
--- a/pkg/sentry/socket/netfilter/netfilter.go
+++ b/pkg/sentry/socket/netfilter/netfilter.go
@@ -42,14 +42,19 @@ func nflog(format string, args ...interface{}) {
}
// GetInfo returns information about iptables.
-func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) {
+func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) {
// Read in the struct and table name.
var info linux.IPTGetinfo
if _, err := info.CopyIn(t, outPtr); err != nil {
return linux.IPTGetinfo{}, syserr.FromError(err)
}
- _, info, err := convertNetstackToBinary(stack, info.Name)
+ var err error
+ if ipv6 {
+ _, info, err = convertNetstackToBinary6(stack, info.Name)
+ } else {
+ _, info, err = convertNetstackToBinary4(stack, info.Name)
+ }
if err != nil {
nflog("couldn't convert iptables: %v", err)
return linux.IPTGetinfo{}, syserr.ErrInvalidArgument
@@ -59,9 +64,9 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT
return info, nil
}
-// GetEntries4 returns netstack's iptables rules encoded for the iptables tool.
+// GetEntries4 returns netstack's iptables rules.
func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) {
- // Read in the ABI struct.
+ // Read in the struct and table name.
var userEntries linux.IPTGetEntries
if _, err := userEntries.CopyIn(t, outPtr); err != nil {
nflog("couldn't copy in entries %q", userEntries.Name)
@@ -70,7 +75,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
// Convert netstack's iptables rules to something that the iptables
// tool can understand.
- entries, _, err := convertNetstackToBinary(stack, userEntries.Name)
+ entries, _, err := convertNetstackToBinary4(stack, userEntries.Name)
if err != nil {
nflog("couldn't read entries: %v", err)
return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument
@@ -83,28 +88,29 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen
return entries, nil
}
-// convertNetstackToBinary converts the iptables as stored in netstack to the
-// format expected by the iptables tool. Linux stores each table as a binary
-// blob that can only be traversed by parsing a bit, reading some offsets,
-// jumping to those offsets, parsing again, etc.
-func convertNetstackToBinary(stk *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) {
- // The table name has to fit in the struct.
- if linux.XT_TABLE_MAXNAMELEN < len(tablename) {
- return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename)
+// GetEntries6 returns netstack's ip6tables rules.
+func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) {
+ // Read in the struct and table name. IPv4 and IPv6 utilize structs
+ // with the same layout.
+ var userEntries linux.IPTGetEntries
+ if _, err := userEntries.CopyIn(t, outPtr); err != nil {
+ nflog("couldn't copy in entries %q", userEntries.Name)
+ return linux.KernelIP6TGetEntries{}, syserr.FromError(err)
}
- table, ok := stk.IPTables().GetTable(tablename.String())
- if !ok {
- return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename)
+ // Convert netstack's iptables rules to something that the iptables
+ // tool can understand.
+ entries, _, err := convertNetstackToBinary6(stack, userEntries.Name)
+ if err != nil {
+ nflog("couldn't read entries: %v", err)
+ return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
+ }
+ if binary.Size(entries) > uintptr(outLen) {
+ nflog("insufficient GetEntries output size: %d", uintptr(outLen))
+ return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument
}
- // Setup the info struct.
- var info linux.IPTGetinfo
- info.ValidHooks = table.ValidHooks()
- copy(info.Name[:], tablename[:])
-
- entries := getEntries4(table, &info)
- return entries, info, nil
+ return entries, nil
}
// setHooksAndUnderflow checks whether the rule at ruleIdx is a hook entrypoint
@@ -128,7 +134,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint
// SetEntries sets iptables rules for a single table. See
// net/ipv4/netfilter/ip_tables.c:translate_table for reference.
-func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
+func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error {
var replace linux.IPTReplace
replaceBuf := optVal[:linux.SizeOfIPTReplace]
optVal = optVal[linux.SizeOfIPTReplace:]
@@ -146,7 +152,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
return syserr.ErrInvalidArgument
}
- offsets, err := modifyEntries4(stk, optVal, &replace, &table)
+ var err *syserr.Error
+ var offsets map[uint32]int
+ if ipv6 {
+ offsets, err = modifyEntries6(stk, optVal, &replace, &table)
+ } else {
+ offsets, err = modifyEntries4(stk, optVal, &replace, &table)
+ }
if err != nil {
return err
}
@@ -163,7 +175,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
table.BuiltinChains[hk] = ruleIdx
}
if offset == replace.Underflow[hook] {
- if !validUnderflow(table.Rules[ruleIdx]) {
+ if !validUnderflow(table.Rules[ruleIdx], ipv6) {
nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx)
return syserr.ErrInvalidArgument
}
@@ -228,7 +240,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
if ruleIdx == stack.HookUnset {
continue
}
- if !isUnconditionalAccept(table.Rules[ruleIdx]) {
+ if !isUnconditionalAccept(table.Rules[ruleIdx], ipv6) {
nflog("hook %d is unsupported.", hook)
return syserr.ErrInvalidArgument
}
@@ -240,7 +252,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error {
// - There are no chains without an unconditional final rule.
// - There are no chains without an unconditional underflow rule.
- return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table))
+ return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6))
+
}
// parseMatchers parses 0 or more matchers from optVal. optVal should contain
@@ -286,11 +299,11 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher,
return matchers, nil
}
-func validUnderflow(rule stack.Rule) bool {
+func validUnderflow(rule stack.Rule, ipv6 bool) bool {
if len(rule.Matchers) != 0 {
return false
}
- if rule.Filter != emptyIPv4Filter {
+ if (ipv6 && rule.Filter != emptyIPv6Filter) || (!ipv6 && rule.Filter != emptyIPv4Filter) {
return false
}
switch rule.Target.(type) {
@@ -301,8 +314,8 @@ func validUnderflow(rule stack.Rule) bool {
}
}
-func isUnconditionalAccept(rule stack.Rule) bool {
- if !validUnderflow(rule) {
+func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool {
+ if !validUnderflow(rule, ipv6) {
return false
}
_, ok := rule.Target.(stack.AcceptTarget)
diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go
index 8ebdaff18..87e41abd8 100644
--- a/pkg/sentry/socket/netfilter/targets.go
+++ b/pkg/sentry/socket/netfilter/targets.go
@@ -218,8 +218,8 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal))
}
- if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber {
+ return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p)
}
var redirectTarget linux.XTRedirectTarget
@@ -232,7 +232,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
// RangeSize should be 1.
if nfRange.RangeSize != 1 {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize)
}
// TODO(gvisor.dev/issue/170): Check if the flags are valid.
@@ -240,7 +240,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
// For now, redirect target only supports destination port change.
// Port range and IP range are not supported yet.
if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags)
}
target.RangeProtoSpecified = true
@@ -249,7 +249,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro
// TODO(gvisor.dev/issue/170): Port range is not supported yet.
if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort {
- return nil, fmt.Errorf("netfilter.SetEntries: invalid argument")
+ return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort)
}
// Convert port from big endian to little endian.
diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go
index 9e2ebc7d4..8da77cc68 100644
--- a/pkg/sentry/socket/netstack/netstack.go
+++ b/pkg/sentry/socket/netstack/netstack.go
@@ -257,6 +257,9 @@ type commonEndpoint interface {
// GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and
// transport.Endpoint.GetSockOpt.
GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error)
+
+ // LastError implements tcpip.Endpoint.LastError.
+ LastError() *tcpip.Error
}
// LINT.IfChange
@@ -997,7 +1000,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in
return getSockOptTCP(t, ep, name, outLen)
case linux.SOL_IPV6:
- return getSockOptIPv6(t, ep, name, outLen)
+ return getSockOptIPv6(t, s, ep, name, outPtr, outLen)
case linux.SOL_IP:
return getSockOptIP(t, s, ep, name, outPtr, outLen, family)
@@ -1030,7 +1033,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam
}
// Get the last error and convert it.
- err := ep.GetSockOpt(tcpip.ErrorOption{})
+ err := ep.LastError()
if err == nil {
optP := primitive.Int32(0)
return &optP, nil
@@ -1455,7 +1458,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal
}
// getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6.
-func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) {
+func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) {
switch name {
case linux.IPV6_V6ONLY:
if outLen < sizeOfInt32 {
@@ -1508,10 +1511,50 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha
vP := primitive.Int32(boolToInt32(v))
return &vP, nil
- case linux.SO_ORIGINAL_DST:
+ case linux.IP6T_ORIGINAL_DST:
// TODO(gvisor.dev/issue/170): ip6tables.
return nil, syserr.ErrInvalidArgument
+ case linux.IP6T_SO_GET_INFO:
+ if outLen < linux.SizeOfIPTGetinfo {
+ return nil, syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true)
+ if err != nil {
+ return nil, err
+ }
+ return &info, nil
+
+ case linux.IP6T_SO_GET_ENTRIES:
+ // IPTGetEntries is reused for IPv6.
+ if outLen < linux.SizeOfIPTGetEntries {
+ return nil, syserr.ErrInvalidArgument
+ }
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return nil, syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return nil, syserr.ErrNoDevice
+ }
+ entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen)
+ if err != nil {
+ return nil, err
+ }
+ return &entries, nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -1649,7 +1692,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
if stack == nil {
return nil, syserr.ErrNoDevice
}
- info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr)
+ info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false)
if err != nil {
return nil, err
}
@@ -1722,7 +1765,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int
return setSockOptTCP(t, ep, name, optVal)
case linux.SOL_IPV6:
- return setSockOptIPv6(t, ep, name, optVal)
+ return setSockOptIPv6(t, s, ep, name, optVal)
case linux.SOL_IP:
return setSockOptIP(t, s, ep, name, optVal)
@@ -2027,7 +2070,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *
}
// setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6.
-func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
+func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error {
switch name {
case linux.IPV6_V6ONLY:
if len(optVal) < sizeOfInt32 {
@@ -2076,6 +2119,27 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte)
return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0))
+ case linux.IP6T_SO_SET_REPLACE:
+ if len(optVal) < linux.SizeOfIP6TReplace {
+ return syserr.ErrInvalidArgument
+ }
+
+ // Only valid for raw IPv6 sockets.
+ if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW {
+ return syserr.ErrProtocolNotAvailable
+ }
+
+ stack := inet.StackFromContext(t)
+ if stack == nil {
+ return syserr.ErrNoDevice
+ }
+ // Stack must be a netstack stack.
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true)
+
+ case linux.IP6T_SO_SET_ADD_COUNTERS:
+ // TODO(gvisor.dev/issue/170): Counter support.
+ return nil
+
default:
emitUnimplementedEventIPv6(t, name)
}
@@ -2271,7 +2335,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in
return syserr.ErrNoDevice
}
// Stack must be a netstack stack.
- return netfilter.SetEntries(stack.(*Stack).Stack, optVal)
+ return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false)
case linux.IPT_SO_SET_ADD_COUNTERS:
// TODO(gvisor.dev/issue/170): Counter support.
diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go
index ab7bab5cd..4bf06d4dc 100644
--- a/pkg/sentry/socket/unix/transport/unix.go
+++ b/pkg/sentry/socket/unix/transport/unix.go
@@ -199,6 +199,9 @@ type Endpoint interface {
// State returns the current state of the socket, as represented by Linux in
// procfs.
State() uint32
+
+ // LastError implements tcpip.Endpoint.LastError.
+ LastError() *tcpip.Error
}
// A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket
@@ -942,7 +945,7 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch opt.(type) {
- case tcpip.ErrorOption, *tcpip.LingerOption:
+ case *tcpip.LingerOption:
return nil
default:
@@ -951,6 +954,11 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
}
}
+// LastError implements Endpoint.LastError.
+func (*baseEndpoint) LastError() *tcpip.Error {
+ return nil
+}
+
// Shutdown closes the read and/or write end of the endpoint connection to its
// peer.
func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error {
diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go
index 748273366..bbafb8b7f 100644
--- a/pkg/sentry/watchdog/watchdog.go
+++ b/pkg/sentry/watchdog/watchdog.go
@@ -96,15 +96,33 @@ const (
Panic
)
+// Set implements flag.Value.
+func (a *Action) Set(v string) error {
+ switch v {
+ case "log", "logwarning":
+ *a = LogWarning
+ case "panic":
+ *a = Panic
+ default:
+ return fmt.Errorf("invalid watchdog action %q", v)
+ }
+ return nil
+}
+
+// Get implements flag.Value.
+func (a *Action) Get() interface{} {
+ return *a
+}
+
// String returns Action's string representation.
-func (a Action) String() string {
- switch a {
+func (a *Action) String() string {
+ switch *a {
case LogWarning:
- return "LogWarning"
+ return "logWarning"
case Panic:
- return "Panic"
+ return "panic"
default:
- panic(fmt.Sprintf("Invalid action: %d", a))
+ panic(fmt.Sprintf("Invalid watchdog action: %d", *a))
}
}
diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go
index d82ed5205..68a954a10 100644
--- a/pkg/tcpip/adapters/gonet/gonet.go
+++ b/pkg/tcpip/adapters/gonet/gonet.go
@@ -541,7 +541,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress,
case <-notifyCh:
}
- err = ep.GetSockOpt(tcpip.ErrorOption{})
+ err = ep.LastError()
}
if err != nil {
ep.Close()
diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go
index 3c552988a..c975ad9cf 100644
--- a/pkg/tcpip/adapters/gonet/gonet_test.go
+++ b/pkg/tcpip/adapters/gonet/gonet_test.go
@@ -104,7 +104,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er
err = ep.Connect(addr)
if err == tcpip.ErrConnectStarted {
<-ch
- err = ep.GetSockOpt(tcpip.ErrorOption{})
+ err = ep.LastError()
}
if err != nil {
return nil, err
diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go
index 0ab089208..91fc26722 100644
--- a/pkg/tcpip/sample/tun_tcp_connect/main.go
+++ b/pkg/tcpip/sample/tun_tcp_connect/main.go
@@ -182,7 +182,7 @@ func main() {
if terr == tcpip.ErrConnectStarted {
fmt.Println("Connect is pending...")
<-notifyCh
- terr = ep.GetSockOpt(tcpip.ErrorOption{})
+ terr = ep.LastError()
}
wq.EventUnregister(&waitEntry)
diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go
index 41ef4236b..30aa41db2 100644
--- a/pkg/tcpip/stack/iptables.go
+++ b/pkg/tcpip/stack/iptables.go
@@ -165,7 +165,11 @@ func EmptyNATTable() Table {
}
// GetTable returns a table by name.
-func (it *IPTables) GetTable(name string) (Table, bool) {
+func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) {
+ // TODO(gvisor.dev/issue/3549): Enable IPv6.
+ if ipv6 {
+ return Table{}, false
+ }
id, ok := nameToID[name]
if !ok {
return Table{}, false
@@ -176,7 +180,11 @@ func (it *IPTables) GetTable(name string) (Table, bool) {
}
// ReplaceTable replaces or inserts table by name.
-func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error {
+func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error {
+ // TODO(gvisor.dev/issue/3549): Enable IPv6.
+ if ipv6 {
+ return tcpip.ErrInvalidOptionValue
+ }
id, ok := nameToID[name]
if !ok {
return tcpip.ErrInvalidOptionValue
diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go
index 73274ada9..fbbd2f50f 100644
--- a/pkg/tcpip/stack/iptables_types.go
+++ b/pkg/tcpip/stack/iptables_types.go
@@ -155,6 +155,11 @@ type IPHeaderFilter struct {
// Protocol matches the transport protocol.
Protocol tcpip.TransportProtocolNumber
+ // CheckProtocol determines whether the Protocol field should be
+ // checked during matching.
+ // TODO(gvisor.dev/issue/3549): Check this field during matching.
+ CheckProtocol bool
+
// Dst matches the destination IP address.
Dst tcpip.Address
diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go
index 6c6e44468..7869bb98b 100644
--- a/pkg/tcpip/stack/transport_test.go
+++ b/pkg/tcpip/stack/transport_test.go
@@ -53,11 +53,11 @@ func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo {
return &f.TransportEndpointInfo
}
-func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats {
+func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats {
return nil
}
-func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
+func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {}
func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint {
return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID}
@@ -100,7 +100,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions
return int64(len(v)), nil, nil
}
-func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
+func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) {
return 0, tcpip.ControlMessages{}, nil
}
@@ -131,10 +131,6 @@ func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.E
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return nil
- }
return tcpip.ErrInvalidEndpointState
}
@@ -169,7 +165,7 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 {
return f.uniqueID
}
-func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
+func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error {
return nil
}
@@ -239,19 +235,19 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s
f.proto.controlCount++
}
-func (f *fakeTransportEndpoint) State() uint32 {
+func (*fakeTransportEndpoint) State() uint32 {
return 0
}
-func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
+func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {}
-func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) {
- return stack.IPTables{}, nil
-}
+func (*fakeTransportEndpoint) Resume(*stack.Stack) {}
-func (f *fakeTransportEndpoint) Resume(*stack.Stack) {}
+func (*fakeTransportEndpoint) Wait() {}
-func (f *fakeTransportEndpoint) Wait() {}
+func (*fakeTransportEndpoint) LastError() *tcpip.Error {
+ return nil
+}
type fakeTransportGoodOption bool
diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go
index 609b8af33..cae943608 100644
--- a/pkg/tcpip/tcpip.go
+++ b/pkg/tcpip/tcpip.go
@@ -620,6 +620,9 @@ type Endpoint interface {
// SetOwner sets the task owner to the endpoint owner.
SetOwner(owner PacketOwner)
+
+ // LastError clears and returns the last error reported by the endpoint.
+ LastError() *Error
}
// LinkPacketInfo holds Link layer information for a received packet.
@@ -839,10 +842,6 @@ const (
PMTUDiscoveryProbe
)
-// ErrorOption is used in GetSockOpt to specify that the last error reported by
-// the endpoint should be cleared and returned.
-type ErrorOption struct{}
-
// BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets
// should bind only on a specific NIC.
type BindToDeviceOption NICID
diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go
index bd6f49eb8..c545c8367 100644
--- a/pkg/tcpip/transport/icmp/endpoint.go
+++ b/pkg/tcpip/transport/icmp/endpoint.go
@@ -415,14 +415,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+func (*endpoint) GetSockOpt(interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
}
func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error {
@@ -836,3 +830,8 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+
+// LastError implements tcpip.Endpoint.LastError.
+func (*endpoint) LastError() *tcpip.Error {
+ return nil
+}
diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go
index 1b03ad6bb..95dc8ed57 100644
--- a/pkg/tcpip/transport/packet/endpoint.go
+++ b/pkg/tcpip/transport/packet/endpoint.go
@@ -356,7 +356,7 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
}
-func (ep *endpoint) takeLastError() *tcpip.Error {
+func (ep *endpoint) LastError() *tcpip.Error {
ep.lastErrorMu.Lock()
defer ep.lastErrorMu.Unlock()
@@ -366,11 +366,7 @@ func (ep *endpoint) takeLastError() *tcpip.Error {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return ep.takeLastError()
- }
+func (*endpoint) GetSockOpt(interface{}) *tcpip.Error {
return tcpip.ErrNotSupported
}
diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go
index edc2b5b61..2087bcfa8 100644
--- a/pkg/tcpip/transport/raw/endpoint.go
+++ b/pkg/tcpip/transport/raw/endpoint.go
@@ -577,14 +577,8 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error {
}
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
-func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
- switch opt.(type) {
- case tcpip.ErrorOption:
- return nil
-
- default:
- return tcpip.ErrUnknownProtocolOption
- }
+func (*endpoint) GetSockOpt(interface{}) *tcpip.Error {
+ return tcpip.ErrUnknownProtocolOption
}
// GetSockOptBool implements tcpip.Endpoint.GetSockOptBool.
@@ -739,3 +733,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats {
// Wait implements stack.TransportEndpoint.Wait.
func (*endpoint) Wait() {}
+
+func (*endpoint) LastError() *tcpip.Error {
+ return nil
+}
diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go
index 290172ac9..72df5c2a1 100644
--- a/pkg/tcpip/transport/tcp/connect.go
+++ b/pkg/tcpip/transport/tcp/connect.go
@@ -491,7 +491,7 @@ func (h *handshake) resolveRoute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.takeLastError()
+ return h.ep.LastError()
}
}
@@ -620,7 +620,7 @@ func (h *handshake) execute() *tcpip.Error {
h.ep.mu.Lock()
}
if n&notifyError != 0 {
- return h.ep.takeLastError()
+ return h.ep.LastError()
}
case wakerForNewSegment:
diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go
index 804e95aea..6074cc24e 100644
--- a/pkg/tcpip/transport/tcp/dual_stack_test.go
+++ b/pkg/tcpip/transport/tcp/dual_stack_test.go
@@ -86,8 +86,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network
// Wait for connection to be established.
select {
case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -194,8 +193,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network
// Wait for connection to be established.
select {
case <-ch:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go
index ff9b8804d..8a5e993b5 100644
--- a/pkg/tcpip/transport/tcp/endpoint.go
+++ b/pkg/tcpip/transport/tcp/endpoint.go
@@ -1234,7 +1234,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) {
e.owner = owner
}
-func (e *endpoint) takeLastError() *tcpip.Error {
+func (e *endpoint) LastError() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
err := e.lastError
@@ -1995,9 +1995,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
- case tcpip.ErrorOption:
- return e.takeLastError()
-
case *tcpip.BindToDeviceOption:
e.LockUser()
*o = tcpip.BindToDeviceOption(e.bindToDevice)
diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go
index 9650bb06c..3d3034d50 100644
--- a/pkg/tcpip/transport/tcp/tcp_test.go
+++ b/pkg/tcpip/transport/tcp/tcp_test.go
@@ -74,8 +74,8 @@ func TestGiveUpConnect(t *testing.T) {
// Wait for ep to become writable.
<-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted {
- t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted)
+ if err := ep.LastError(); err != tcpip.ErrAborted {
+ t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted)
}
// Call Connect again to retreive the handshake failure status
@@ -3023,8 +3023,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) {
// Wait for connection to be established.
select {
case <-ch:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
- t.Fatalf("GetSockOpt failed: %s", err)
+ if err := c.EP.LastError(); err != nil {
+ t.Fatalf("Connect failed: %s", err)
}
case <-time.After(1 * time.Second):
t.Fatalf("Timed out waiting for connection")
@@ -4411,7 +4411,7 @@ func TestSelfConnect(t *testing.T) {
}
<-notifyCh
- if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ if err := ep.LastError(); err != nil {
t.Fatalf("Connect failed: %s", err)
}
diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go
index b6031354e..1f5340cd0 100644
--- a/pkg/tcpip/transport/tcp/testing/context/context.go
+++ b/pkg/tcpip/transport/tcp/testing/context/context.go
@@ -638,7 +638,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte)
// Wait for connection to be established.
select {
case <-notifyCh:
- if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil {
+ if err := c.EP.LastError(); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
@@ -882,8 +882,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) *
// Wait for connection to be established.
select {
case <-notifyCh:
- err = c.EP.GetSockOpt(tcpip.ErrorOption{})
- if err != nil {
+ if err := c.EP.LastError(); err != nil {
c.t.Fatalf("Unexpected error when connecting: %v", err)
}
case <-time.After(1 * time.Second):
diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go
index 0a9d3c6cf..1d5ebe3f2 100644
--- a/pkg/tcpip/transport/udp/endpoint.go
+++ b/pkg/tcpip/transport/udp/endpoint.go
@@ -209,7 +209,7 @@ func (e *endpoint) UniqueID() uint64 {
return e.uniqueID
}
-func (e *endpoint) takeLastError() *tcpip.Error {
+func (e *endpoint) LastError() *tcpip.Error {
e.lastErrorMu.Lock()
defer e.lastErrorMu.Unlock()
@@ -268,7 +268,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {}
// Read reads data from the endpoint. This method does not block if
// there is no data pending.
func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) {
- if err := e.takeLastError(); err != nil {
+ if err := e.LastError(); err != nil {
return buffer.View{}, tcpip.ControlMessages{}, err
}
@@ -411,7 +411,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c
}
func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) {
- if err := e.takeLastError(); err != nil {
+ if err := e.LastError(); err != nil {
return 0, nil, err
}
@@ -962,8 +962,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) {
// GetSockOpt implements tcpip.Endpoint.GetSockOpt.
func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error {
switch o := opt.(type) {
- case tcpip.ErrorOption:
- return e.takeLastError()
case *tcpip.MulticastInterfaceOption:
e.mu.Lock()
*o = tcpip.MulticastInterfaceOption{
diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go
index 42d79f5c2..b7f873392 100644
--- a/pkg/test/testutil/testutil.go
+++ b/pkg/test/testutil/testutil.go
@@ -138,20 +138,23 @@ func TestConfig(t *testing.T) *config.Config {
if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok {
logDir = dir + "/"
}
- return &config.Config{
- Debug: true,
- DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"),
- LogFormat: "text",
- DebugLogFormat: "text",
- LogPackets: true,
- Network: config.NetworkNone,
- Strace: true,
- Platform: "ptrace",
- FileAccess: config.FileAccessExclusive,
- NumNetworkChannels: 1,
-
- TestOnlyAllowRunAsCurrentUserWithoutChroot: true,
- }
+
+ // Only register flags if config is being used. Otherwise anyone that uses
+ // testutil will get flags registered and they may conflict.
+ config.RegisterFlags()
+
+ conf, err := config.NewFromFlags()
+ if err != nil {
+ panic(err)
+ }
+ // Change test defaults.
+ conf.Debug = true
+ conf.DebugLog = path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%")
+ conf.LogPackets = true
+ conf.Network = config.NetworkNone
+ conf.Strace = true
+ conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true
+ return conf
}
// NewSpecWithArgs creates a simple spec with the given args suitable for use