diff options
47 files changed, 1698 insertions, 939 deletions
diff --git a/debian/postinst.sh b/debian/postinst.sh index d1e28e17b..6a326f823 100755 --- a/debian/postinst.sh +++ b/debian/postinst.sh @@ -21,7 +21,7 @@ fi # Update docker configuration. if [ -f /etc/docker/daemon.json ]; then runsc install - if systemctl status docker 2>/dev/null; then + if systemctl is-active -q docker; then systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 fi fi diff --git a/pkg/abi/linux/netfilter_ipv6.go b/pkg/abi/linux/netfilter_ipv6.go index 9bb9efb10..f6117024c 100644 --- a/pkg/abi/linux/netfilter_ipv6.go +++ b/pkg/abi/linux/netfilter_ipv6.go @@ -290,6 +290,19 @@ type IP6TIP struct { const SizeOfIP6TIP = 136 +// Flags in IP6TIP.Flags. Corresponding constants are in +// include/uapi/linux/netfilter_ipv6/ip6_tables.h. +const ( + // Whether to check the Protocol field. + IP6T_F_PROTO = 0x01 + // Whether to match the TOS field. + IP6T_F_TOS = 0x02 + // Indicates that the jump target is an aboslute GOTO, not an offset. + IP6T_F_GOTO = 0x04 + // Enables all flags. + IP6T_F_MASK = 0x07 +) + // Flags in IP6TIP.InverseFlags. Corresponding constants are in // include/uapi/linux/netfilter_ipv6/ip6_tables.h. const ( diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index d9d5e6bcb..57d8542b9 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -234,6 +234,39 @@ const ( LeaksLogTraces ) +// Set implements flag.Value. +func (l *LeakMode) Set(v string) error { + switch v { + case "disabled": + *l = NoLeakChecking + case "log-names": + *l = LeaksLogWarning + case "log-traces": + *l = LeaksLogTraces + default: + return fmt.Errorf("invalid ref leak mode %q", v) + } + return nil +} + +// Get implements flag.Value. +func (l *LeakMode) Get() interface{} { + return *l +} + +// String implements flag.Value. +func (l *LeakMode) String() string { + switch *l { + case NoLeakChecking: + return "disabled" + case LeaksLogWarning: + return "log-names" + case LeaksLogTraces: + return "log-traces" + } + panic(fmt.Sprintf("invalid ref leak mode %q", *l)) +} + // leakMode stores the current mode for the reference leak checker. // // Values must be one of the LeakMode values. diff --git a/pkg/sentry/platform/kvm/virtual_map.go b/pkg/sentry/platform/kvm/virtual_map.go index c8897d34f..4dcdbf8a7 100644 --- a/pkg/sentry/platform/kvm/virtual_map.go +++ b/pkg/sentry/platform/kvm/virtual_map.go @@ -34,7 +34,7 @@ type virtualRegion struct { } // mapsLine matches a single line from /proc/PID/maps. -var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2}:[0-9a-f]{2,} [0-9]+\\s+(.*)") +var mapsLine = regexp.MustCompile("([0-9a-f]+)-([0-9a-f]+) ([r-][w-][x-][sp]) ([0-9a-f]+) [0-9a-f]{2,3}:[0-9a-f]{2,} [0-9]+\\s+(.*)") // excludeRegion returns true if these regions should be excluded from the // physical map. Virtual regions need to be excluded if get_user_pages will diff --git a/pkg/sentry/socket/netfilter/BUILD b/pkg/sentry/socket/netfilter/BUILD index 795620589..8aea0200f 100644 --- a/pkg/sentry/socket/netfilter/BUILD +++ b/pkg/sentry/socket/netfilter/BUILD @@ -7,6 +7,7 @@ go_library( srcs = [ "extensions.go", "ipv4.go", + "ipv6.go", "netfilter.go", "owner_matcher.go", "targets.go", diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index 4fb887e49..e4c55a100 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -36,14 +36,37 @@ var emptyIPv4Filter = stack.IPHeaderFilter{ SrcMask: "\x00\x00\x00\x00", } -func getEntries4(table stack.Table, info *linux.IPTGetinfo) linux.KernelIPTGetEntries { +// convertNetstackToBinary4 converts the iptables as stored in netstack to the +// format expected by the iptables tool. Linux stores each table as a binary +// blob that can only be traversed by parsing a little data, reading some +// offsets, jumping to those offsets, parsing again, etc. +func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { + // The table name has to fit in the struct. + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) + } + + table, ok := stk.IPTables().GetTable(tablename.String(), false) + if !ok { + return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + + // Setup the info struct. + entries, info := getEntries4(table, tablename) + return entries, info, nil +} + +func getEntries4(table stack.Table, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo) { + var info linux.IPTGetinfo var entries linux.KernelIPTGetEntries + copy(info.Name[:], tablename[:]) copy(entries.Name[:], info.Name[:]) + info.ValidHooks = table.ValidHooks() for ruleIdx, rule := range table.Rules { nflog("convert to binary: current offset: %d", entries.Size) - setHooksAndUnderflow(info, table, entries.Size, ruleIdx) + setHooksAndUnderflow(&info, table, entries.Size, ruleIdx) // Each rule corresponds to an entry. entry := linux.KernelIPTEntry{ Entry: linux.IPTEntry{ @@ -100,7 +123,7 @@ func getEntries4(table stack.Table, info *linux.IPTGetinfo) linux.KernelIPTGetEn info.Size = entries.Size nflog("convert to binary: finished with an marshalled size of %d", info.Size) - return entries + return entries, info } func modifyEntries4(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { @@ -205,7 +228,9 @@ func filterFromIPTIP(iptip linux.IPTIP) (stack.IPHeaderFilter, error) { ifnameMask := string(iptip.OutputInterfaceMask[:n]) return stack.IPHeaderFilter{ - Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + // A Protocol value of 0 indicates all protocols match. + CheckProtocol: iptip.Protocol != 0, Dst: tcpip.Address(iptip.Dst[:]), DstMask: tcpip.Address(iptip.DstMask[:]), DstInvert: iptip.InverseFlags&linux.IPT_INV_DSTIP != 0, diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go new file mode 100644 index 000000000..3b2c1becd --- /dev/null +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -0,0 +1,265 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package netfilter + +import ( + "bytes" + "fmt" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/binary" + "gvisor.dev/gvisor/pkg/syserr" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/usermem" +) + +// emptyIPv6Filter is for comparison with a rule's filters to determine whether +// it is also empty. It is immutable. +var emptyIPv6Filter = stack.IPHeaderFilter{ + Dst: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + DstMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + Src: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", + SrcMask: "\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00", +} + +// convertNetstackToBinary6 converts the ip6tables as stored in netstack to the +// format expected by the iptables tool. Linux stores each table as a binary +// blob that can only be traversed by parsing a little data, reading some +// offsets, jumping to those offsets, parsing again, etc. +func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo, error) { + // The table name has to fit in the struct. + if linux.XT_TABLE_MAXNAMELEN < len(tablename) { + return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) + } + + table, ok := stk.IPTables().GetTable(tablename.String(), true) + if !ok { + return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + } + + // Setup the info struct, which is the same in IPv4 and IPv6. + entries, info := getEntries6(table, tablename) + return entries, info, nil +} + +func getEntries6(table stack.Table, tablename linux.TableName) (linux.KernelIP6TGetEntries, linux.IPTGetinfo) { + var info linux.IPTGetinfo + var entries linux.KernelIP6TGetEntries + copy(info.Name[:], tablename[:]) + copy(entries.Name[:], info.Name[:]) + info.ValidHooks = table.ValidHooks() + + for ruleIdx, rule := range table.Rules { + nflog("convert to binary: current offset: %d", entries.Size) + + setHooksAndUnderflow(&info, table, entries.Size, ruleIdx) + // Each rule corresponds to an entry. + entry := linux.KernelIP6TEntry{ + Entry: linux.IP6TEntry{ + IPv6: linux.IP6TIP{ + Protocol: uint16(rule.Filter.Protocol), + }, + NextOffset: linux.SizeOfIP6TEntry, + TargetOffset: linux.SizeOfIP6TEntry, + }, + } + copy(entry.Entry.IPv6.Dst[:], rule.Filter.Dst) + copy(entry.Entry.IPv6.DstMask[:], rule.Filter.DstMask) + copy(entry.Entry.IPv6.Src[:], rule.Filter.Src) + copy(entry.Entry.IPv6.SrcMask[:], rule.Filter.SrcMask) + copy(entry.Entry.IPv6.OutputInterface[:], rule.Filter.OutputInterface) + copy(entry.Entry.IPv6.OutputInterfaceMask[:], rule.Filter.OutputInterfaceMask) + if rule.Filter.DstInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_DSTIP + } + if rule.Filter.SrcInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_SRCIP + } + if rule.Filter.OutputInterfaceInvert { + entry.Entry.IPv6.InverseFlags |= linux.IP6T_INV_VIA_OUT + } + if rule.Filter.CheckProtocol { + entry.Entry.IPv6.Flags |= linux.IP6T_F_PROTO + } + + for _, matcher := range rule.Matchers { + // Serialize the matcher and add it to the + // entry. + serialized := marshalMatcher(matcher) + nflog("convert to binary: matcher serialized as: %v", serialized) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("matcher %T is not 64-bit aligned", matcher)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + entry.Entry.TargetOffset += uint16(len(serialized)) + } + + // Serialize and append the target. + serialized := marshalTarget(rule.Target) + if len(serialized)%8 != 0 { + panic(fmt.Sprintf("target %T is not 64-bit aligned", rule.Target)) + } + entry.Elems = append(entry.Elems, serialized...) + entry.Entry.NextOffset += uint16(len(serialized)) + + nflog("convert to binary: adding entry: %+v", entry) + + entries.Size += uint32(entry.Entry.NextOffset) + entries.Entrytable = append(entries.Entrytable, entry) + info.NumEntries++ + } + + info.Size = entries.Size + nflog("convert to binary: finished with an marshalled size of %d", info.Size) + return entries, info +} + +func modifyEntries6(stk *stack.Stack, optVal []byte, replace *linux.IPTReplace, table *stack.Table) (map[uint32]int, *syserr.Error) { + nflog("set entries: setting entries in table %q", replace.Name.String()) + + // Convert input into a list of rules and their offsets. + var offset uint32 + // offsets maps rule byte offsets to their position in table.Rules. + offsets := map[uint32]int{} + for entryIdx := uint32(0); entryIdx < replace.NumEntries; entryIdx++ { + nflog("set entries: processing entry at offset %d", offset) + + // Get the struct ipt_entry. + if len(optVal) < linux.SizeOfIP6TEntry { + nflog("optVal has insufficient size for entry %d", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + var entry linux.IP6TEntry + buf := optVal[:linux.SizeOfIP6TEntry] + binary.Unmarshal(buf, usermem.ByteOrder, &entry) + initialOptValLen := len(optVal) + optVal = optVal[linux.SizeOfIP6TEntry:] + + if entry.TargetOffset < linux.SizeOfIP6TEntry { + nflog("entry has too-small target offset %d", entry.TargetOffset) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): We should support more IPTIP + // filtering fields. + filter, err := filterFromIP6TIP(entry.IPv6) + if err != nil { + nflog("bad iptip: %v", err) + return nil, syserr.ErrInvalidArgument + } + + // TODO(gvisor.dev/issue/170): Matchers and targets can specify + // that they only work for certain protocols, hooks, tables. + // Get matchers. + matchersSize := entry.TargetOffset - linux.SizeOfIP6TEntry + if len(optVal) < int(matchersSize) { + nflog("entry doesn't have enough room for its matchers (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + matchers, err := parseMatchers(filter, optVal[:matchersSize]) + if err != nil { + nflog("failed to parse matchers: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[matchersSize:] + + // Get the target of the rule. + targetSize := entry.NextOffset - entry.TargetOffset + if len(optVal) < int(targetSize) { + nflog("entry doesn't have enough room for its target (only %d bytes remain)", len(optVal)) + return nil, syserr.ErrInvalidArgument + } + target, err := parseTarget(filter, optVal[:targetSize]) + if err != nil { + nflog("failed to parse target: %v", err) + return nil, syserr.ErrInvalidArgument + } + optVal = optVal[targetSize:] + + table.Rules = append(table.Rules, stack.Rule{ + Filter: filter, + Target: target, + Matchers: matchers, + }) + offsets[offset] = int(entryIdx) + offset += uint32(entry.NextOffset) + + if initialOptValLen-len(optVal) != int(entry.NextOffset) { + nflog("entry NextOffset is %d, but entry took up %d bytes", entry.NextOffset, initialOptValLen-len(optVal)) + return nil, syserr.ErrInvalidArgument + } + } + return offsets, nil +} + +func filterFromIP6TIP(iptip linux.IP6TIP) (stack.IPHeaderFilter, error) { + if containsUnsupportedFields6(iptip) { + return stack.IPHeaderFilter{}, fmt.Errorf("unsupported fields in struct iptip: %+v", iptip) + } + if len(iptip.Dst) != header.IPv6AddressSize || len(iptip.DstMask) != header.IPv6AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of destination (%d) and/or destination mask (%d) fields", len(iptip.Dst), len(iptip.DstMask)) + } + if len(iptip.Src) != header.IPv6AddressSize || len(iptip.SrcMask) != header.IPv6AddressSize { + return stack.IPHeaderFilter{}, fmt.Errorf("incorrect length of source (%d) and/or source mask (%d) fields", len(iptip.Src), len(iptip.SrcMask)) + } + + n := bytes.IndexByte([]byte(iptip.OutputInterface[:]), 0) + if n == -1 { + n = len(iptip.OutputInterface) + } + ifname := string(iptip.OutputInterface[:n]) + + n = bytes.IndexByte([]byte(iptip.OutputInterfaceMask[:]), 0) + if n == -1 { + n = len(iptip.OutputInterfaceMask) + } + ifnameMask := string(iptip.OutputInterfaceMask[:n]) + + return stack.IPHeaderFilter{ + Protocol: tcpip.TransportProtocolNumber(iptip.Protocol), + // In ip6tables a flag controls whether to check the protocol. + CheckProtocol: iptip.Flags&linux.IP6T_F_PROTO != 0, + Dst: tcpip.Address(iptip.Dst[:]), + DstMask: tcpip.Address(iptip.DstMask[:]), + DstInvert: iptip.InverseFlags&linux.IP6T_INV_DSTIP != 0, + Src: tcpip.Address(iptip.Src[:]), + SrcMask: tcpip.Address(iptip.SrcMask[:]), + SrcInvert: iptip.InverseFlags&linux.IP6T_INV_SRCIP != 0, + OutputInterface: ifname, + OutputInterfaceMask: ifnameMask, + OutputInterfaceInvert: iptip.InverseFlags&linux.IP6T_INV_VIA_OUT != 0, + }, nil +} + +func containsUnsupportedFields6(iptip linux.IP6TIP) bool { + // The following features are supported: + // - Protocol + // - Dst and DstMask + // - Src and SrcMask + // - The inverse destination IP check flag + // - OutputInterface, OutputInterfaceMask and its inverse. + var emptyInterface = [linux.IFNAMSIZ]byte{} + flagMask := uint8(linux.IP6T_F_PROTO) + // Disable any supported inverse flags. + inverseMask := uint8(linux.IP6T_INV_DSTIP) | uint8(linux.IP6T_INV_SRCIP) | uint8(linux.IP6T_INV_VIA_OUT) + return iptip.InputInterface != emptyInterface || + iptip.InputInterfaceMask != emptyInterface || + iptip.Flags&^flagMask != 0 || + iptip.InverseFlags&^inverseMask != 0 || + iptip.TOS != 0 +} diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index df256676f..3e1735079 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -42,14 +42,19 @@ func nflog(format string, args ...interface{}) { } // GetInfo returns information about iptables. -func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPTGetinfo, *syserr.Error) { +func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. var info linux.IPTGetinfo if _, err := info.CopyIn(t, outPtr); err != nil { return linux.IPTGetinfo{}, syserr.FromError(err) } - _, info, err := convertNetstackToBinary(stack, info.Name) + var err error + if ipv6 { + _, info, err = convertNetstackToBinary6(stack, info.Name) + } else { + _, info, err = convertNetstackToBinary4(stack, info.Name) + } if err != nil { nflog("couldn't convert iptables: %v", err) return linux.IPTGetinfo{}, syserr.ErrInvalidArgument @@ -59,9 +64,9 @@ func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr) (linux.IPT return info, nil } -// GetEntries4 returns netstack's iptables rules encoded for the iptables tool. +// GetEntries4 returns netstack's iptables rules. func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIPTGetEntries, *syserr.Error) { - // Read in the ABI struct. + // Read in the struct and table name. var userEntries linux.IPTGetEntries if _, err := userEntries.CopyIn(t, outPtr); err != nil { nflog("couldn't copy in entries %q", userEntries.Name) @@ -70,7 +75,7 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen // Convert netstack's iptables rules to something that the iptables // tool can understand. - entries, _, err := convertNetstackToBinary(stack, userEntries.Name) + entries, _, err := convertNetstackToBinary4(stack, userEntries.Name) if err != nil { nflog("couldn't read entries: %v", err) return linux.KernelIPTGetEntries{}, syserr.ErrInvalidArgument @@ -83,28 +88,29 @@ func GetEntries4(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen return entries, nil } -// convertNetstackToBinary converts the iptables as stored in netstack to the -// format expected by the iptables tool. Linux stores each table as a binary -// blob that can only be traversed by parsing a bit, reading some offsets, -// jumping to those offsets, parsing again, etc. -func convertNetstackToBinary(stk *stack.Stack, tablename linux.TableName) (linux.KernelIPTGetEntries, linux.IPTGetinfo, error) { - // The table name has to fit in the struct. - if linux.XT_TABLE_MAXNAMELEN < len(tablename) { - return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) +// GetEntries6 returns netstack's ip6tables rules. +func GetEntries6(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, outLen int) (linux.KernelIP6TGetEntries, *syserr.Error) { + // Read in the struct and table name. IPv4 and IPv6 utilize structs + // with the same layout. + var userEntries linux.IPTGetEntries + if _, err := userEntries.CopyIn(t, outPtr); err != nil { + nflog("couldn't copy in entries %q", userEntries.Name) + return linux.KernelIP6TGetEntries{}, syserr.FromError(err) } - table, ok := stk.IPTables().GetTable(tablename.String()) - if !ok { - return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) + // Convert netstack's iptables rules to something that the iptables + // tool can understand. + entries, _, err := convertNetstackToBinary6(stack, userEntries.Name) + if err != nil { + nflog("couldn't read entries: %v", err) + return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument + } + if binary.Size(entries) > uintptr(outLen) { + nflog("insufficient GetEntries output size: %d", uintptr(outLen)) + return linux.KernelIP6TGetEntries{}, syserr.ErrInvalidArgument } - // Setup the info struct. - var info linux.IPTGetinfo - info.ValidHooks = table.ValidHooks() - copy(info.Name[:], tablename[:]) - - entries := getEntries4(table, &info) - return entries, info, nil + return entries, nil } // setHooksAndUnderflow checks whether the rule at ruleIdx is a hook entrypoint @@ -128,7 +134,7 @@ func setHooksAndUnderflow(info *linux.IPTGetinfo, table stack.Table, offset uint // SetEntries sets iptables rules for a single table. See // net/ipv4/netfilter/ip_tables.c:translate_table for reference. -func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { +func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { var replace linux.IPTReplace replaceBuf := optVal[:linux.SizeOfIPTReplace] optVal = optVal[linux.SizeOfIPTReplace:] @@ -146,7 +152,13 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { return syserr.ErrInvalidArgument } - offsets, err := modifyEntries4(stk, optVal, &replace, &table) + var err *syserr.Error + var offsets map[uint32]int + if ipv6 { + offsets, err = modifyEntries6(stk, optVal, &replace, &table) + } else { + offsets, err = modifyEntries4(stk, optVal, &replace, &table) + } if err != nil { return err } @@ -163,7 +175,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { table.BuiltinChains[hk] = ruleIdx } if offset == replace.Underflow[hook] { - if !validUnderflow(table.Rules[ruleIdx]) { + if !validUnderflow(table.Rules[ruleIdx], ipv6) { nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) return syserr.ErrInvalidArgument } @@ -228,7 +240,7 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { if ruleIdx == stack.HookUnset { continue } - if !isUnconditionalAccept(table.Rules[ruleIdx]) { + if !isUnconditionalAccept(table.Rules[ruleIdx], ipv6) { nflog("hook %d is unsupported.", hook) return syserr.ErrInvalidArgument } @@ -240,7 +252,8 @@ func SetEntries(stk *stack.Stack, optVal []byte) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table)) + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6)) + } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -286,11 +299,11 @@ func parseMatchers(filter stack.IPHeaderFilter, optVal []byte) ([]stack.Matcher, return matchers, nil } -func validUnderflow(rule stack.Rule) bool { +func validUnderflow(rule stack.Rule, ipv6 bool) bool { if len(rule.Matchers) != 0 { return false } - if rule.Filter != emptyIPv4Filter { + if (ipv6 && rule.Filter != emptyIPv6Filter) || (!ipv6 && rule.Filter != emptyIPv4Filter) { return false } switch rule.Target.(type) { @@ -301,8 +314,8 @@ func validUnderflow(rule stack.Rule) bool { } } -func isUnconditionalAccept(rule stack.Rule) bool { - if !validUnderflow(rule) { +func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { + if !validUnderflow(rule, ipv6) { return false } _, ok := rule.Target.(stack.AcceptTarget) diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 8ebdaff18..87e41abd8 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -218,8 +218,8 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro return nil, fmt.Errorf("netfilter.SetEntries: optVal has insufficient size for redirect target %d", len(optVal)) } - if filter.Protocol != header.TCPProtocolNumber && filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + if p := filter.Protocol; p != header.TCPProtocolNumber && p != header.UDPProtocolNumber { + return nil, fmt.Errorf("netfilter.SetEntries: bad proto %d", p) } var redirectTarget linux.XTRedirectTarget @@ -232,7 +232,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro // RangeSize should be 1. if nfRange.RangeSize != 1 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + return nil, fmt.Errorf("netfilter.SetEntries: bad rangesize %d", nfRange.RangeSize) } // TODO(gvisor.dev/issue/170): Check if the flags are valid. @@ -240,7 +240,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro // For now, redirect target only supports destination port change. // Port range and IP range are not supported yet. if nfRange.RangeIPV4.Flags&linux.NF_NAT_RANGE_PROTO_SPECIFIED == 0 { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + return nil, fmt.Errorf("netfilter.SetEntries: invalid range flags %d", nfRange.RangeIPV4.Flags) } target.RangeProtoSpecified = true @@ -249,7 +249,7 @@ func parseTarget(filter stack.IPHeaderFilter, optVal []byte) (stack.Target, erro // TODO(gvisor.dev/issue/170): Port range is not supported yet. if nfRange.RangeIPV4.MinPort != nfRange.RangeIPV4.MaxPort { - return nil, fmt.Errorf("netfilter.SetEntries: invalid argument") + return nil, fmt.Errorf("netfilter.SetEntries: minport != maxport (%d, %d)", nfRange.RangeIPV4.MinPort, nfRange.RangeIPV4.MaxPort) } // Convert port from big endian to little endian. diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 9e2ebc7d4..8da77cc68 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -257,6 +257,9 @@ type commonEndpoint interface { // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt and // transport.Endpoint.GetSockOpt. GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) + + // LastError implements tcpip.Endpoint.LastError. + LastError() *tcpip.Error } // LINT.IfChange @@ -997,7 +1000,7 @@ func GetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, family in return getSockOptTCP(t, ep, name, outLen) case linux.SOL_IPV6: - return getSockOptIPv6(t, ep, name, outLen) + return getSockOptIPv6(t, s, ep, name, outPtr, outLen) case linux.SOL_IP: return getSockOptIP(t, s, ep, name, outPtr, outLen, family) @@ -1030,7 +1033,7 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam } // Get the last error and convert it. - err := ep.GetSockOpt(tcpip.ErrorOption{}) + err := ep.LastError() if err == nil { optP := primitive.Int32(0) return &optP, nil @@ -1455,7 +1458,7 @@ func getSockOptTCP(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal } // getSockOptIPv6 implements GetSockOpt when level is SOL_IPV6. -func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marshal.Marshallable, *syserr.Error) { +func getSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, outPtr usermem.Addr, outLen int) (marshal.Marshallable, *syserr.Error) { switch name { case linux.IPV6_V6ONLY: if outLen < sizeOfInt32 { @@ -1508,10 +1511,50 @@ func getSockOptIPv6(t *kernel.Task, ep commonEndpoint, name, outLen int) (marsha vP := primitive.Int32(boolToInt32(v)) return &vP, nil - case linux.SO_ORIGINAL_DST: + case linux.IP6T_ORIGINAL_DST: // TODO(gvisor.dev/issue/170): ip6tables. return nil, syserr.ErrInvalidArgument + case linux.IP6T_SO_GET_INFO: + if outLen < linux.SizeOfIPTGetinfo { + return nil, syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, true) + if err != nil { + return nil, err + } + return &info, nil + + case linux.IP6T_SO_GET_ENTRIES: + // IPTGetEntries is reused for IPv6. + if outLen < linux.SizeOfIPTGetEntries { + return nil, syserr.ErrInvalidArgument + } + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return nil, syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return nil, syserr.ErrNoDevice + } + entries, err := netfilter.GetEntries6(t, stack.(*Stack).Stack, outPtr, outLen) + if err != nil { + return nil, err + } + return &entries, nil + default: emitUnimplementedEventIPv6(t, name) } @@ -1649,7 +1692,7 @@ func getSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in if stack == nil { return nil, syserr.ErrNoDevice } - info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr) + info, err := netfilter.GetInfo(t, stack.(*Stack).Stack, outPtr, false) if err != nil { return nil, err } @@ -1722,7 +1765,7 @@ func SetSockOpt(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, level int return setSockOptTCP(t, ep, name, optVal) case linux.SOL_IPV6: - return setSockOptIPv6(t, ep, name, optVal) + return setSockOptIPv6(t, s, ep, name, optVal) case linux.SOL_IP: return setSockOptIP(t, s, ep, name, optVal) @@ -2027,7 +2070,7 @@ func setSockOptTCP(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) * } // setSockOptIPv6 implements SetSockOpt when level is SOL_IPV6. -func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) *syserr.Error { +func setSockOptIPv6(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name int, optVal []byte) *syserr.Error { switch name { case linux.IPV6_V6ONLY: if len(optVal) < sizeOfInt32 { @@ -2076,6 +2119,27 @@ func setSockOptIPv6(t *kernel.Task, ep commonEndpoint, name int, optVal []byte) return syserr.TranslateNetstackError(ep.SetSockOptBool(tcpip.ReceiveTClassOption, v != 0)) + case linux.IP6T_SO_SET_REPLACE: + if len(optVal) < linux.SizeOfIP6TReplace { + return syserr.ErrInvalidArgument + } + + // Only valid for raw IPv6 sockets. + if family, skType, _ := s.Type(); family != linux.AF_INET6 || skType != linux.SOCK_RAW { + return syserr.ErrProtocolNotAvailable + } + + stack := inet.StackFromContext(t) + if stack == nil { + return syserr.ErrNoDevice + } + // Stack must be a netstack stack. + return netfilter.SetEntries(stack.(*Stack).Stack, optVal, true) + + case linux.IP6T_SO_SET_ADD_COUNTERS: + // TODO(gvisor.dev/issue/170): Counter support. + return nil + default: emitUnimplementedEventIPv6(t, name) } @@ -2271,7 +2335,7 @@ func setSockOptIP(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, name in return syserr.ErrNoDevice } // Stack must be a netstack stack. - return netfilter.SetEntries(stack.(*Stack).Stack, optVal) + return netfilter.SetEntries(stack.(*Stack).Stack, optVal, false) case linux.IPT_SO_SET_ADD_COUNTERS: // TODO(gvisor.dev/issue/170): Counter support. diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index ab7bab5cd..4bf06d4dc 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -199,6 +199,9 @@ type Endpoint interface { // State returns the current state of the socket, as represented by Linux in // procfs. State() uint32 + + // LastError implements tcpip.Endpoint.LastError. + LastError() *tcpip.Error } // A Credentialer is a socket or endpoint that supports the SO_PASSCRED socket @@ -942,7 +945,7 @@ func (e *baseEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch opt.(type) { - case tcpip.ErrorOption, *tcpip.LingerOption: + case *tcpip.LingerOption: return nil default: @@ -951,6 +954,11 @@ func (e *baseEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { } } +// LastError implements Endpoint.LastError. +func (*baseEndpoint) LastError() *tcpip.Error { + return nil +} + // Shutdown closes the read and/or write end of the endpoint connection to its // peer. func (e *baseEndpoint) Shutdown(flags tcpip.ShutdownFlags) *syserr.Error { diff --git a/pkg/sentry/watchdog/watchdog.go b/pkg/sentry/watchdog/watchdog.go index 748273366..bbafb8b7f 100644 --- a/pkg/sentry/watchdog/watchdog.go +++ b/pkg/sentry/watchdog/watchdog.go @@ -96,15 +96,33 @@ const ( Panic ) +// Set implements flag.Value. +func (a *Action) Set(v string) error { + switch v { + case "log", "logwarning": + *a = LogWarning + case "panic": + *a = Panic + default: + return fmt.Errorf("invalid watchdog action %q", v) + } + return nil +} + +// Get implements flag.Value. +func (a *Action) Get() interface{} { + return *a +} + // String returns Action's string representation. -func (a Action) String() string { - switch a { +func (a *Action) String() string { + switch *a { case LogWarning: - return "LogWarning" + return "logWarning" case Panic: - return "Panic" + return "panic" default: - panic(fmt.Sprintf("Invalid action: %d", a)) + panic(fmt.Sprintf("Invalid watchdog action: %d", *a)) } } diff --git a/pkg/tcpip/adapters/gonet/gonet.go b/pkg/tcpip/adapters/gonet/gonet.go index d82ed5205..68a954a10 100644 --- a/pkg/tcpip/adapters/gonet/gonet.go +++ b/pkg/tcpip/adapters/gonet/gonet.go @@ -541,7 +541,7 @@ func DialContextTCP(ctx context.Context, s *stack.Stack, addr tcpip.FullAddress, case <-notifyCh: } - err = ep.GetSockOpt(tcpip.ErrorOption{}) + err = ep.LastError() } if err != nil { ep.Close() diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 3c552988a..c975ad9cf 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -104,7 +104,7 @@ func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Er err = ep.Connect(addr) if err == tcpip.ErrConnectStarted { <-ch - err = ep.GetSockOpt(tcpip.ErrorOption{}) + err = ep.LastError() } if err != nil { return nil, err diff --git a/pkg/tcpip/sample/tun_tcp_connect/main.go b/pkg/tcpip/sample/tun_tcp_connect/main.go index 0ab089208..91fc26722 100644 --- a/pkg/tcpip/sample/tun_tcp_connect/main.go +++ b/pkg/tcpip/sample/tun_tcp_connect/main.go @@ -182,7 +182,7 @@ func main() { if terr == tcpip.ErrConnectStarted { fmt.Println("Connect is pending...") <-notifyCh - terr = ep.GetSockOpt(tcpip.ErrorOption{}) + terr = ep.LastError() } wq.EventUnregister(&waitEntry) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 41ef4236b..30aa41db2 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -165,7 +165,11 @@ func EmptyNATTable() Table { } // GetTable returns a table by name. -func (it *IPTables) GetTable(name string) (Table, bool) { +func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { + // TODO(gvisor.dev/issue/3549): Enable IPv6. + if ipv6 { + return Table{}, false + } id, ok := nameToID[name] if !ok { return Table{}, false @@ -176,7 +180,11 @@ func (it *IPTables) GetTable(name string) (Table, bool) { } // ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table) *tcpip.Error { +func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { + // TODO(gvisor.dev/issue/3549): Enable IPv6. + if ipv6 { + return tcpip.ErrInvalidOptionValue + } id, ok := nameToID[name] if !ok { return tcpip.ErrInvalidOptionValue diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 73274ada9..fbbd2f50f 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -155,6 +155,11 @@ type IPHeaderFilter struct { // Protocol matches the transport protocol. Protocol tcpip.TransportProtocolNumber + // CheckProtocol determines whether the Protocol field should be + // checked during matching. + // TODO(gvisor.dev/issue/3549): Check this field during matching. + CheckProtocol bool + // Dst matches the destination IP address. Dst tcpip.Address diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 6c6e44468..7869bb98b 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -53,11 +53,11 @@ func (f *fakeTransportEndpoint) Info() tcpip.EndpointInfo { return &f.TransportEndpointInfo } -func (f *fakeTransportEndpoint) Stats() tcpip.EndpointStats { +func (*fakeTransportEndpoint) Stats() tcpip.EndpointStats { return nil } -func (f *fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} +func (*fakeTransportEndpoint) SetOwner(owner tcpip.PacketOwner) {} func newFakeTransportEndpoint(s *stack.Stack, proto *fakeTransportProtocol, netProto tcpip.NetworkProtocolNumber, uniqueID uint64) tcpip.Endpoint { return &fakeTransportEndpoint{stack: s, TransportEndpointInfo: stack.TransportEndpointInfo{NetProto: netProto}, proto: proto, uniqueID: uniqueID} @@ -100,7 +100,7 @@ func (f *fakeTransportEndpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions return int64(len(v)), nil, nil } -func (f *fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { +func (*fakeTransportEndpoint) Peek([][]byte) (int64, tcpip.ControlMessages, *tcpip.Error) { return 0, tcpip.ControlMessages{}, nil } @@ -131,10 +131,6 @@ func (*fakeTransportEndpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.E // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (*fakeTransportEndpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - } return tcpip.ErrInvalidEndpointState } @@ -169,7 +165,7 @@ func (f *fakeTransportEndpoint) UniqueID() uint64 { return f.uniqueID } -func (f *fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { +func (*fakeTransportEndpoint) ConnectEndpoint(e tcpip.Endpoint) *tcpip.Error { return nil } @@ -239,19 +235,19 @@ func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, s f.proto.controlCount++ } -func (f *fakeTransportEndpoint) State() uint32 { +func (*fakeTransportEndpoint) State() uint32 { return 0 } -func (f *fakeTransportEndpoint) ModerateRecvBuf(copied int) {} +func (*fakeTransportEndpoint) ModerateRecvBuf(copied int) {} -func (f *fakeTransportEndpoint) IPTables() (stack.IPTables, error) { - return stack.IPTables{}, nil -} +func (*fakeTransportEndpoint) Resume(*stack.Stack) {} -func (f *fakeTransportEndpoint) Resume(*stack.Stack) {} +func (*fakeTransportEndpoint) Wait() {} -func (f *fakeTransportEndpoint) Wait() {} +func (*fakeTransportEndpoint) LastError() *tcpip.Error { + return nil +} type fakeTransportGoodOption bool diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index 609b8af33..cae943608 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -620,6 +620,9 @@ type Endpoint interface { // SetOwner sets the task owner to the endpoint owner. SetOwner(owner PacketOwner) + + // LastError clears and returns the last error reported by the endpoint. + LastError() *Error } // LinkPacketInfo holds Link layer information for a received packet. @@ -839,10 +842,6 @@ const ( PMTUDiscoveryProbe ) -// ErrorOption is used in GetSockOpt to specify that the last error reported by -// the endpoint should be cleared and returned. -type ErrorOption struct{} - // BindToDeviceOption is used by SetSockOpt/GetSockOpt to specify that sockets // should bind only on a specific NIC. type BindToDeviceOption NICID diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index bd6f49eb8..c545c8367 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -415,14 +415,8 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } +func (*endpoint) GetSockOpt(interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption } func send4(r *stack.Route, ident uint16, data buffer.View, ttl uint8, owner tcpip.PacketOwner) *tcpip.Error { @@ -836,3 +830,8 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} + +// LastError implements tcpip.Endpoint.LastError. +func (*endpoint) LastError() *tcpip.Error { + return nil +} diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 1b03ad6bb..95dc8ed57 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -356,7 +356,7 @@ func (ep *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } } -func (ep *endpoint) takeLastError() *tcpip.Error { +func (ep *endpoint) LastError() *tcpip.Error { ep.lastErrorMu.Lock() defer ep.lastErrorMu.Unlock() @@ -366,11 +366,7 @@ func (ep *endpoint) takeLastError() *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (ep *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return ep.takeLastError() - } +func (*endpoint) GetSockOpt(interface{}) *tcpip.Error { return tcpip.ErrNotSupported } diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index edc2b5b61..2087bcfa8 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -577,14 +577,8 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { } // GetSockOpt implements tcpip.Endpoint.GetSockOpt. -func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { - switch opt.(type) { - case tcpip.ErrorOption: - return nil - - default: - return tcpip.ErrUnknownProtocolOption - } +func (*endpoint) GetSockOpt(interface{}) *tcpip.Error { + return tcpip.ErrUnknownProtocolOption } // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. @@ -739,3 +733,7 @@ func (e *endpoint) Stats() tcpip.EndpointStats { // Wait implements stack.TransportEndpoint.Wait. func (*endpoint) Wait() {} + +func (*endpoint) LastError() *tcpip.Error { + return nil +} diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 290172ac9..72df5c2a1 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -491,7 +491,7 @@ func (h *handshake) resolveRoute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.takeLastError() + return h.ep.LastError() } } @@ -620,7 +620,7 @@ func (h *handshake) execute() *tcpip.Error { h.ep.mu.Lock() } if n¬ifyError != 0 { - return h.ep.takeLastError() + return h.ep.LastError() } case wakerForNewSegment: diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 804e95aea..6074cc24e 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -86,8 +86,7 @@ func testV4Connect(t *testing.T, c *context.Context, checkers ...checker.Network // Wait for connection to be established. select { case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -194,8 +193,7 @@ func testV6Connect(t *testing.T, c *context.Context, checkers ...checker.Network // Wait for connection to be established. select { case <-ch: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index ff9b8804d..8a5e993b5 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -1234,7 +1234,7 @@ func (e *endpoint) SetOwner(owner tcpip.PacketOwner) { e.owner = owner } -func (e *endpoint) takeLastError() *tcpip.Error { +func (e *endpoint) LastError() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() err := e.lastError @@ -1995,9 +1995,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { - case tcpip.ErrorOption: - return e.takeLastError() - case *tcpip.BindToDeviceOption: e.LockUser() *o = tcpip.BindToDeviceOption(e.bindToDevice) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index 9650bb06c..3d3034d50 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -74,8 +74,8 @@ func TestGiveUpConnect(t *testing.T) { // Wait for ep to become writable. <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != tcpip.ErrAborted { - t.Fatalf("got ep.GetSockOpt(tcpip.ErrorOption{}) = %s, want = %s", err, tcpip.ErrAborted) + if err := ep.LastError(); err != tcpip.ErrAborted { + t.Fatalf("got ep.LastError() = %s, want = %s", err, tcpip.ErrAborted) } // Call Connect again to retreive the handshake failure status @@ -3023,8 +3023,8 @@ func TestSynOptionsOnActiveConnect(t *testing.T) { // Wait for connection to be established. select { case <-ch: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { - t.Fatalf("GetSockOpt failed: %s", err) + if err := c.EP.LastError(); err != nil { + t.Fatalf("Connect failed: %s", err) } case <-time.After(1 * time.Second): t.Fatalf("Timed out waiting for connection") @@ -4411,7 +4411,7 @@ func TestSelfConnect(t *testing.T) { } <-notifyCh - if err := ep.GetSockOpt(tcpip.ErrorOption{}); err != nil { + if err := ep.LastError(); err != nil { t.Fatalf("Connect failed: %s", err) } diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index b6031354e..1f5340cd0 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -638,7 +638,7 @@ func (c *Context) Connect(iss seqnum.Value, rcvWnd seqnum.Size, options []byte) // Wait for connection to be established. select { case <-notifyCh: - if err := c.EP.GetSockOpt(tcpip.ErrorOption{}); err != nil { + if err := c.EP.LastError(); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): @@ -882,8 +882,7 @@ func (c *Context) CreateConnectedWithOptions(wantOptions header.TCPSynOptions) * // Wait for connection to be established. select { case <-notifyCh: - err = c.EP.GetSockOpt(tcpip.ErrorOption{}) - if err != nil { + if err := c.EP.LastError(); err != nil { c.t.Fatalf("Unexpected error when connecting: %v", err) } case <-time.After(1 * time.Second): diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index 0a9d3c6cf..1d5ebe3f2 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -209,7 +209,7 @@ func (e *endpoint) UniqueID() uint64 { return e.uniqueID } -func (e *endpoint) takeLastError() *tcpip.Error { +func (e *endpoint) LastError() *tcpip.Error { e.lastErrorMu.Lock() defer e.lastErrorMu.Unlock() @@ -268,7 +268,7 @@ func (e *endpoint) ModerateRecvBuf(copied int) {} // Read reads data from the endpoint. This method does not block if // there is no data pending. func (e *endpoint) Read(addr *tcpip.FullAddress) (buffer.View, tcpip.ControlMessages, *tcpip.Error) { - if err := e.takeLastError(); err != nil { + if err := e.LastError(); err != nil { return buffer.View{}, tcpip.ControlMessages{}, err } @@ -411,7 +411,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c } func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-chan struct{}, *tcpip.Error) { - if err := e.takeLastError(); err != nil { + if err := e.LastError(); err != nil { return 0, nil, err } @@ -962,8 +962,6 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { // GetSockOpt implements tcpip.Endpoint.GetSockOpt. func (e *endpoint) GetSockOpt(opt interface{}) *tcpip.Error { switch o := opt.(type) { - case tcpip.ErrorOption: - return e.takeLastError() case *tcpip.MulticastInterfaceOption: e.mu.Lock() *o = tcpip.MulticastInterfaceOption{ diff --git a/pkg/test/testutil/testutil.go b/pkg/test/testutil/testutil.go index 42d79f5c2..b7f873392 100644 --- a/pkg/test/testutil/testutil.go +++ b/pkg/test/testutil/testutil.go @@ -138,20 +138,23 @@ func TestConfig(t *testing.T) *config.Config { if dir, ok := os.LookupEnv("TEST_UNDECLARED_OUTPUTS_DIR"); ok { logDir = dir + "/" } - return &config.Config{ - Debug: true, - DebugLog: path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%"), - LogFormat: "text", - DebugLogFormat: "text", - LogPackets: true, - Network: config.NetworkNone, - Strace: true, - Platform: "ptrace", - FileAccess: config.FileAccessExclusive, - NumNetworkChannels: 1, - - TestOnlyAllowRunAsCurrentUserWithoutChroot: true, - } + + // Only register flags if config is being used. Otherwise anyone that uses + // testutil will get flags registered and they may conflict. + config.RegisterFlags() + + conf, err := config.NewFromFlags() + if err != nil { + panic(err) + } + // Change test defaults. + conf.Debug = true + conf.DebugLog = path.Join(logDir, "runsc.log."+t.Name()+".%TIMESTAMP%.%COMMAND%") + conf.LogPackets = true + conf.Network = config.NetworkNone + conf.Strace = true + conf.TestOnlyAllowRunAsCurrentUserWithoutChroot = true + return conf } // NewSpecWithArgs creates a simple spec with the given args suitable for use diff --git a/runsc/boot/loader_test.go b/runsc/boot/loader_test.go index 03cbaec33..2343ce76c 100644 --- a/runsc/boot/loader_test.go +++ b/runsc/boot/loader_test.go @@ -44,15 +44,19 @@ func init() { if err := fsgofer.OpenProcSelfFD(); err != nil { panic(err) } + config.RegisterFlags() } func testConfig() *config.Config { - return &config.Config{ - RootDir: "unused_root_dir", - Network: config.NetworkNone, - DisableSeccomp: true, - Platform: "ptrace", + conf, err := config.NewFromFlags() + if err != nil { + panic(err) } + // Change test defaults. + conf.RootDir = "unused_root_dir" + conf.Network = config.NetworkNone + conf.DisableSeccomp = true + return conf } // testSpec returns a simple spec that can be used in tests. @@ -546,7 +550,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, }, "tmpfs": { @@ -600,7 +604,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, { Dev: "9pfs-/dev/fd-foo", @@ -658,7 +662,7 @@ func TestRestoreEnvironment(t *testing.T) { { Dev: "9pfs-/", Flags: fs.MountSourceFlags{ReadOnly: true}, - DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true,cache=remote_revalidating", + DataString: "trans=fd,rfdno=0,wfdno=0,privateunixsocket=true", }, }, "tmpfs": { diff --git a/runsc/boot/strace.go b/runsc/boot/strace.go index 176981f74..c21648a32 100644 --- a/runsc/boot/strace.go +++ b/runsc/boot/strace.go @@ -15,6 +15,8 @@ package boot import ( + "strings" + "gvisor.dev/gvisor/pkg/sentry/strace" "gvisor.dev/gvisor/runsc/config" ) @@ -37,5 +39,5 @@ func enableStrace(conf *config.Config) error { strace.EnableAll(strace.SinkTypeLog) return nil } - return strace.Enable(conf.StraceSyscalls, strace.SinkTypeLog) + return strace.Enable(strings.Split(conf.StraceSyscalls, ","), strace.SinkTypeLog) } diff --git a/runsc/config/BUILD b/runsc/config/BUILD index 3c8713d53..b1672bb9d 100644 --- a/runsc/config/BUILD +++ b/runsc/config/BUILD @@ -1,4 +1,4 @@ -load("//tools:defs.bzl", "go_library") +load("//tools:defs.bzl", "go_library", "go_test") package(licenses = ["notice"]) @@ -6,10 +6,23 @@ go_library( name = "config", srcs = [ "config.go", + "flags.go", ], visibility = ["//:sandbox"], deps = [ "//pkg/refs", "//pkg/sentry/watchdog", + "//pkg/sync", + "//runsc/flag", ], ) + +go_test( + name = "config_test", + size = "small", + srcs = [ + "config_test.go", + ], + library = ":config", + deps = ["//runsc/flag"], +) diff --git a/runsc/config/config.go b/runsc/config/config.go index 8cf0378d5..bca27ebf1 100644 --- a/runsc/config/config.go +++ b/runsc/config/config.go @@ -19,254 +19,105 @@ package config import ( "fmt" - "strconv" - "strings" "gvisor.dev/gvisor/pkg/refs" "gvisor.dev/gvisor/pkg/sentry/watchdog" ) -// FileAccessType tells how the filesystem is accessed. -type FileAccessType int - -const ( - // FileAccessShared sends IO requests to a Gofer process that validates the - // requests and forwards them to the host. - FileAccessShared FileAccessType = iota - - // FileAccessExclusive is the same as FileAccessShared, but enables - // extra caching for improved performance. It should only be used if - // the sandbox has exclusive access to the filesystem. - FileAccessExclusive -) - -// MakeFileAccessType converts type from string. -func MakeFileAccessType(s string) (FileAccessType, error) { - switch s { - case "shared": - return FileAccessShared, nil - case "exclusive": - return FileAccessExclusive, nil - default: - return 0, fmt.Errorf("invalid file access type %q", s) - } -} - -func (f FileAccessType) String() string { - switch f { - case FileAccessShared: - return "shared" - case FileAccessExclusive: - return "exclusive" - default: - return fmt.Sprintf("unknown(%d)", f) - } -} - -// NetworkType tells which network stack to use. -type NetworkType int - -const ( - // NetworkSandbox uses internal network stack, isolated from the host. - NetworkSandbox NetworkType = iota - - // NetworkHost redirects network related syscalls to the host network. - NetworkHost - - // NetworkNone sets up just loopback using netstack. - NetworkNone -) - -// MakeNetworkType converts type from string. -func MakeNetworkType(s string) (NetworkType, error) { - switch s { - case "sandbox": - return NetworkSandbox, nil - case "host": - return NetworkHost, nil - case "none": - return NetworkNone, nil - default: - return 0, fmt.Errorf("invalid network type %q", s) - } -} - -func (n NetworkType) String() string { - switch n { - case NetworkSandbox: - return "sandbox" - case NetworkHost: - return "host" - case NetworkNone: - return "none" - default: - return fmt.Sprintf("unknown(%d)", n) - } -} - -// MakeWatchdogAction converts type from string. -func MakeWatchdogAction(s string) (watchdog.Action, error) { - switch strings.ToLower(s) { - case "log", "logwarning": - return watchdog.LogWarning, nil - case "panic": - return watchdog.Panic, nil - default: - return 0, fmt.Errorf("invalid watchdog action %q", s) - } -} - -// MakeRefsLeakMode converts type from string. -func MakeRefsLeakMode(s string) (refs.LeakMode, error) { - switch strings.ToLower(s) { - case "disabled": - return refs.NoLeakChecking, nil - case "log-names": - return refs.LeaksLogWarning, nil - case "log-traces": - return refs.LeaksLogTraces, nil - default: - return 0, fmt.Errorf("invalid refs leakmode %q", s) - } -} - -func refsLeakModeToString(mode refs.LeakMode) string { - switch mode { - // If not set, default it to disabled. - case refs.UninitializedLeakChecking, refs.NoLeakChecking: - return "disabled" - case refs.LeaksLogWarning: - return "log-names" - case refs.LeaksLogTraces: - return "log-traces" - default: - panic(fmt.Sprintf("Invalid leakmode: %d", mode)) - } -} - -// QueueingDiscipline is used to specify the kind of Queueing Discipline to -// apply for a give FDBasedLink. -type QueueingDiscipline int - -const ( - // QDiscNone disables any queueing for the underlying FD. - QDiscNone QueueingDiscipline = iota - - // QDiscFIFO applies a simple fifo based queue to the underlying - // FD. - QDiscFIFO -) - -// MakeQueueingDiscipline if possible the equivalent QueuingDiscipline for s -// else returns an error. -func MakeQueueingDiscipline(s string) (QueueingDiscipline, error) { - switch s { - case "none": - return QDiscNone, nil - case "fifo": - return QDiscFIFO, nil - default: - return 0, fmt.Errorf("unsupported qdisc specified: %q", s) - } -} - -// String implements fmt.Stringer. -func (q QueueingDiscipline) String() string { - switch q { - case QDiscNone: - return "none" - case QDiscFIFO: - return "fifo" - default: - panic(fmt.Sprintf("Invalid queueing discipline: %d", q)) - } -} - // Config holds configuration that is not part of the runtime spec. +// +// Follow these steps to add a new flag: +// 1. Create a new field in Config. +// 2. Add a field tag with the flag name +// 3. Register a new flag in flags.go, with name and description +// 4. Add any necessary validation into validate() +// 5. If adding an enum, follow the same pattern as FileAccessType +// type Config struct { // RootDir is the runtime root directory. - RootDir string + RootDir string `flag:"root"` // Debug indicates that debug logging should be enabled. - Debug bool + Debug bool `flag:"debug"` // LogFilename is the filename to log to, if not empty. - LogFilename string + LogFilename string `flag:"log"` // LogFormat is the log format. - LogFormat string + LogFormat string `flag:"log-format"` // DebugLog is the path to log debug information to, if not empty. - DebugLog string + DebugLog string `flag:"debug-log"` // PanicLog is the path to log GO's runtime messages, if not empty. - PanicLog string + PanicLog string `flag:"panic-log"` // DebugLogFormat is the log format for debug. - DebugLogFormat string + DebugLogFormat string `flag:"debug-log-format"` // FileAccess indicates how the filesystem is accessed. - FileAccess FileAccessType + FileAccess FileAccessType `flag:"file-access"` // Overlay is whether to wrap the root filesystem in an overlay. - Overlay bool + Overlay bool `flag:"overlay"` // FSGoferHostUDS enables the gofer to mount a host UDS. - FSGoferHostUDS bool + FSGoferHostUDS bool `flag:"fsgofer-host-uds"` // Network indicates what type of network to use. - Network NetworkType + Network NetworkType `flag:"network"` // EnableRaw indicates whether raw sockets should be enabled. Raw // sockets are disabled by stripping CAP_NET_RAW from the list of // capabilities. - EnableRaw bool + EnableRaw bool `flag:"net-raw"` // HardwareGSO indicates that hardware segmentation offload is enabled. - HardwareGSO bool + HardwareGSO bool `flag:"gso"` // SoftwareGSO indicates that software segmentation offload is enabled. - SoftwareGSO bool + SoftwareGSO bool `flag:"software-gso"` // TXChecksumOffload indicates that TX Checksum Offload is enabled. - TXChecksumOffload bool + TXChecksumOffload bool `flag:"tx-checksum-offload"` // RXChecksumOffload indicates that RX Checksum Offload is enabled. - RXChecksumOffload bool + RXChecksumOffload bool `flag:"rx-checksum-offload"` // QDisc indicates the type of queuening discipline to use by default // for non-loopback interfaces. - QDisc QueueingDiscipline + QDisc QueueingDiscipline `flag:"qdisc"` // LogPackets indicates that all network packets should be logged. - LogPackets bool + LogPackets bool `flag:"log-packets"` // Platform is the platform to run on. - Platform string + Platform string `flag:"platform"` // Strace indicates that strace should be enabled. - Strace bool + Strace bool `flag:"strace"` - // StraceSyscalls is the set of syscalls to trace. If StraceEnable is - // true and this list is empty, then all syscalls will be traced. - StraceSyscalls []string + // StraceSyscalls is the set of syscalls to trace (comma-separated values). + // If StraceEnable is true and this string is empty, then all syscalls will + // be traced. + StraceSyscalls string `flag:"strace-syscalls"` // StraceLogSize is the max size of data blobs to display. - StraceLogSize uint + StraceLogSize uint `flag:"strace-log-size"` // DisableSeccomp indicates whether seccomp syscall filters should be // disabled. Pardon the double negation, but default to enabled is important. DisableSeccomp bool // WatchdogAction sets what action the watchdog takes when triggered. - WatchdogAction watchdog.Action + WatchdogAction watchdog.Action `flag:"watchdog-action"` // PanicSignal registers signal handling that panics. Usually set to // SIGUSR2(12) to troubleshoot hangs. -1 disables it. - PanicSignal int + PanicSignal int `flag:"panic-signal"` // ProfileEnable is set to prepare the sandbox to be profiled. - ProfileEnable bool + ProfileEnable bool `flag:"profile"` // RestoreFile is the path to the saved container image RestoreFile string @@ -274,105 +125,209 @@ type Config struct { // NumNetworkChannels controls the number of AF_PACKET sockets that map // to the same underlying network device. This allows netstack to better // scale for high throughput use cases. - NumNetworkChannels int + NumNetworkChannels int `flag:"num-network-channels"` // Rootless allows the sandbox to be started with a user that is not root. // Defense is depth measures are weaker with rootless. Specifically, the // sandbox and Gofer process run as root inside a user namespace with root // mapped to the caller's user. - Rootless bool + Rootless bool `flag:"rootless"` // AlsoLogToStderr allows to send log messages to stderr. - AlsoLogToStderr bool + AlsoLogToStderr bool `flag:"alsologtostderr"` // ReferenceLeakMode sets reference leak check mode - ReferenceLeakMode refs.LeakMode + ReferenceLeak refs.LeakMode `flag:"ref-leak-mode"` // OverlayfsStaleRead instructs the sandbox to assume that the root mount // is on a Linux overlayfs mount, which does not necessarily preserve // coherence between read-only and subsequent writable file descriptors // representing the "same" file. - OverlayfsStaleRead bool + OverlayfsStaleRead bool `flag:"overlayfs-stale-read"` // CPUNumFromQuota sets CPU number count to available CPU quota, using // least integer value greater than or equal to quota. // // E.g. 0.2 CPU quota will result in 1, and 1.9 in 2. - CPUNumFromQuota bool + CPUNumFromQuota bool `flag:"cpu-num-from-quota"` - // Enables VFS2 (not plumbed through yet). - VFS2 bool + // Enables VFS2. + VFS2 bool `flag:"vfs2"` - // Enables FUSE usage (not plumbed through yet). - FUSE bool + // Enables FUSE usage. + FUSE bool `flag:"fuse"` // TestOnlyAllowRunAsCurrentUserWithoutChroot should only be used in // tests. It allows runsc to start the sandbox process as the current // user, and without chrooting the sandbox process. This can be // necessary in test environments that have limited capabilities. - TestOnlyAllowRunAsCurrentUserWithoutChroot bool + TestOnlyAllowRunAsCurrentUserWithoutChroot bool `flag:"TESTONLY-unsafe-nonroot"` // TestOnlyTestNameEnv should only be used in tests. It looks up for the // test name in the container environment variables and adds it to the debug // log file name. This is done to help identify the log with the test when // multiple tests are run in parallel, since there is no way to pass // parameters to the runtime from docker. - TestOnlyTestNameEnv string + TestOnlyTestNameEnv string `flag:"TESTONLY-test-name-env"` +} + +func (c *Config) validate() error { + if c.FileAccess == FileAccessShared && c.Overlay { + return fmt.Errorf("overlay flag is incompatible with shared file access") + } + if c.NumNetworkChannels <= 0 { + return fmt.Errorf("num_network_channels must be > 0, got: %d", c.NumNetworkChannels) + } + return nil } -// ToFlags returns a slice of flags that correspond to the given Config. -func (c *Config) ToFlags() []string { - f := []string{ - "--root=" + c.RootDir, - "--debug=" + strconv.FormatBool(c.Debug), - "--log=" + c.LogFilename, - "--log-format=" + c.LogFormat, - "--debug-log=" + c.DebugLog, - "--panic-log=" + c.PanicLog, - "--debug-log-format=" + c.DebugLogFormat, - "--file-access=" + c.FileAccess.String(), - "--overlay=" + strconv.FormatBool(c.Overlay), - "--fsgofer-host-uds=" + strconv.FormatBool(c.FSGoferHostUDS), - "--network=" + c.Network.String(), - "--log-packets=" + strconv.FormatBool(c.LogPackets), - "--platform=" + c.Platform, - "--strace=" + strconv.FormatBool(c.Strace), - "--strace-syscalls=" + strings.Join(c.StraceSyscalls, ","), - "--strace-log-size=" + strconv.Itoa(int(c.StraceLogSize)), - "--watchdog-action=" + c.WatchdogAction.String(), - "--panic-signal=" + strconv.Itoa(c.PanicSignal), - "--profile=" + strconv.FormatBool(c.ProfileEnable), - "--net-raw=" + strconv.FormatBool(c.EnableRaw), - "--num-network-channels=" + strconv.Itoa(c.NumNetworkChannels), - "--rootless=" + strconv.FormatBool(c.Rootless), - "--alsologtostderr=" + strconv.FormatBool(c.AlsoLogToStderr), - "--ref-leak-mode=" + refsLeakModeToString(c.ReferenceLeakMode), - "--gso=" + strconv.FormatBool(c.HardwareGSO), - "--software-gso=" + strconv.FormatBool(c.SoftwareGSO), - "--rx-checksum-offload=" + strconv.FormatBool(c.RXChecksumOffload), - "--tx-checksum-offload=" + strconv.FormatBool(c.TXChecksumOffload), - "--overlayfs-stale-read=" + strconv.FormatBool(c.OverlayfsStaleRead), - "--qdisc=" + c.QDisc.String(), - "--vfs2=" + strconv.FormatBool(c.VFS2), - "--fuse=" + strconv.FormatBool(c.FUSE), +// FileAccessType tells how the filesystem is accessed. +type FileAccessType int + +const ( + // FileAccessExclusive is the same as FileAccessShared, but enables + // extra caching for improved performance. It should only be used if + // the sandbox has exclusive access to the filesystem. + FileAccessExclusive FileAccessType = iota + + // FileAccessShared sends IO requests to a Gofer process that validates the + // requests and forwards them to the host. + FileAccessShared +) + +func fileAccessTypePtr(v FileAccessType) *FileAccessType { + return &v +} + +// Set implements flag.Value. +func (f *FileAccessType) Set(v string) error { + switch v { + case "shared": + *f = FileAccessShared + case "exclusive": + *f = FileAccessExclusive + default: + return fmt.Errorf("invalid file access type %q", v) } - if c.CPUNumFromQuota { - f = append(f, "--cpu-num-from-quota") + return nil +} + +// Get implements flag.Value. +func (f *FileAccessType) Get() interface{} { + return *f +} + +// String implements flag.Value. +func (f *FileAccessType) String() string { + switch *f { + case FileAccessShared: + return "shared" + case FileAccessExclusive: + return "exclusive" } - if c.VFS2 { - f = append(f, "--vfs2=true") + panic(fmt.Sprintf("Invalid file access type %v", *f)) +} + +// NetworkType tells which network stack to use. +type NetworkType int + +const ( + // NetworkSandbox uses internal network stack, isolated from the host. + NetworkSandbox NetworkType = iota + + // NetworkHost redirects network related syscalls to the host network. + NetworkHost + + // NetworkNone sets up just loopback using netstack. + NetworkNone +) + +func networkTypePtr(v NetworkType) *NetworkType { + return &v +} + +// Set implements flag.Value. +func (n *NetworkType) Set(v string) error { + switch v { + case "sandbox": + *n = NetworkSandbox + case "host": + *n = NetworkHost + case "none": + *n = NetworkNone + default: + return fmt.Errorf("invalid network type %q", v) } - if c.FUSE { - f = append(f, "--fuse=true") + return nil +} + +// Get implements flag.Value. +func (n *NetworkType) Get() interface{} { + return *n +} + +// String implements flag.Value. +func (n *NetworkType) String() string { + switch *n { + case NetworkSandbox: + return "sandbox" + case NetworkHost: + return "host" + case NetworkNone: + return "none" } + panic(fmt.Sprintf("Invalid network type %v", *n)) +} + +// QueueingDiscipline is used to specify the kind of Queueing Discipline to +// apply for a give FDBasedLink. +type QueueingDiscipline int - // Only include these if set since it is never to be used by users. - if c.TestOnlyAllowRunAsCurrentUserWithoutChroot { - f = append(f, "--TESTONLY-unsafe-nonroot=true") +const ( + // QDiscNone disables any queueing for the underlying FD. + QDiscNone QueueingDiscipline = iota + + // QDiscFIFO applies a simple fifo based queue to the underlying FD. + QDiscFIFO +) + +func queueingDisciplinePtr(v QueueingDiscipline) *QueueingDiscipline { + return &v +} + +// Set implements flag.Value. +func (q *QueueingDiscipline) Set(v string) error { + switch v { + case "none": + *q = QDiscNone + case "fifo": + *q = QDiscFIFO + default: + return fmt.Errorf("invalid qdisc %q", v) } - if len(c.TestOnlyTestNameEnv) != 0 { - f = append(f, "--TESTONLY-test-name-env="+c.TestOnlyTestNameEnv) + return nil +} + +// Get implements flag.Value. +func (q *QueueingDiscipline) Get() interface{} { + return *q +} + +// String implements flag.Value. +func (q *QueueingDiscipline) String() string { + switch *q { + case QDiscNone: + return "none" + case QDiscFIFO: + return "fifo" } + panic(fmt.Sprintf("Invalid qdisc %v", *q)) +} + +func leakModePtr(v refs.LeakMode) *refs.LeakMode { + return &v +} - return f +func watchdogActionPtr(v watchdog.Action) *watchdog.Action { + return &v } diff --git a/runsc/config/config_test.go b/runsc/config/config_test.go new file mode 100644 index 000000000..af7867a2a --- /dev/null +++ b/runsc/config/config_test.go @@ -0,0 +1,185 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "strings" + "testing" + + "gvisor.dev/gvisor/runsc/flag" +) + +func init() { + RegisterFlags() +} + +func TestDefault(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + // "--root" is always set to something different than the default. Reset it + // to make it easier to test that default values do not generate flags. + c.RootDir = "" + + // All defaults doesn't require setting flags. + flags := c.ToFlags() + if len(flags) > 0 { + t.Errorf("default flags not set correctly for: %s", flags) + } +} + +func setDefault(name string) { + fl := flag.CommandLine.Lookup(name) + fl.Value.Set(fl.DefValue) +} + +func TestFromFlags(t *testing.T) { + flag.CommandLine.Lookup("root").Value.Set("some-path") + flag.CommandLine.Lookup("debug").Value.Set("true") + flag.CommandLine.Lookup("num-network-channels").Value.Set("123") + flag.CommandLine.Lookup("network").Value.Set("none") + defer func() { + setDefault("root") + setDefault("debug") + setDefault("num-network-channels") + setDefault("network") + }() + + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + if want := "some-path"; c.RootDir != want { + t.Errorf("RootDir=%v, want: %v", c.RootDir, want) + } + if want := true; c.Debug != want { + t.Errorf("Debug=%v, want: %v", c.Debug, want) + } + if want := 123; c.NumNetworkChannels != want { + t.Errorf("NumNetworkChannels=%v, want: %v", c.NumNetworkChannels, want) + } + if want := NetworkNone; c.Network != want { + t.Errorf("Network=%v, want: %v", c.Network, want) + } +} + +func TestToFlags(t *testing.T) { + c, err := NewFromFlags() + if err != nil { + t.Fatal(err) + } + c.RootDir = "some-path" + c.Debug = true + c.NumNetworkChannels = 123 + c.Network = NetworkNone + + flags := c.ToFlags() + if len(flags) != 4 { + t.Errorf("wrong number of flags set, want: 4, got: %d: %s", len(flags), flags) + } + t.Logf("Flags: %s", flags) + fm := map[string]string{} + for _, f := range flags { + kv := strings.Split(f, "=") + fm[kv[0]] = kv[1] + } + for name, want := range map[string]string{ + "--root": "some-path", + "--debug": "true", + "--num-network-channels": "123", + "--network": "none", + } { + if got, ok := fm[name]; ok { + if got != want { + t.Errorf("flag %q, want: %q, got: %q", name, want, got) + } + } else { + t.Errorf("flag %q not set", name) + } + } +} + +// TestInvalidFlags checks that enum flags fail when value is not in enum set. +func TestInvalidFlags(t *testing.T) { + for _, tc := range []struct { + name string + error string + }{ + { + name: "file-access", + error: "invalid file access type", + }, + { + name: "network", + error: "invalid network type", + }, + { + name: "qdisc", + error: "invalid qdisc", + }, + { + name: "watchdog-action", + error: "invalid watchdog action", + }, + { + name: "ref-leak-mode", + error: "invalid ref leak mode", + }, + } { + t.Run(tc.name, func(t *testing.T) { + defer setDefault(tc.name) + if err := flag.CommandLine.Lookup(tc.name).Value.Set("invalid"); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("flag.Value.Set(invalid) wrong error reported: %v", err) + } + }) + } +} + +func TestValidationFail(t *testing.T) { + for _, tc := range []struct { + name string + flags map[string]string + error string + }{ + { + name: "shared+overlay", + flags: map[string]string{ + "file-access": "shared", + "overlay": "true", + }, + error: "overlay flag is incompatible", + }, + { + name: "network-channels", + flags: map[string]string{ + "num-network-channels": "-1", + }, + error: "num_network_channels must be > 0", + }, + } { + t.Run(tc.name, func(t *testing.T) { + for name, val := range tc.flags { + defer setDefault(name) + if err := flag.CommandLine.Lookup(name).Value.Set(val); err != nil { + t.Errorf("%s=%q: %v", name, val, err) + } + } + if _, err := NewFromFlags(); err == nil || !strings.Contains(err.Error(), tc.error) { + t.Errorf("NewFromFlags() wrong error reported: %v", err) + } + }) + } +} diff --git a/runsc/config/flags.go b/runsc/config/flags.go new file mode 100644 index 000000000..488a4b9fb --- /dev/null +++ b/runsc/config/flags.go @@ -0,0 +1,168 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package config + +import ( + "fmt" + "os" + "path/filepath" + "reflect" + "strconv" + + "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sentry/watchdog" + "gvisor.dev/gvisor/pkg/sync" + "gvisor.dev/gvisor/runsc/flag" +) + +var registration sync.Once + +// This is the set of flags used to populate Config. +func RegisterFlags() { + registration.Do(func() { + // Although these flags are not part of the OCI spec, they are used by + // Docker, and thus should not be changed. + flag.String("root", "", "root directory for storage of container state.") + flag.String("log", "", "file path where internal debug information is written, default is stdout.") + flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") + flag.Bool("debug", false, "enable debug logging.") + + // These flags are unique to runsc, and are used to configure parts of the + // system that are not covered by the runtime spec. + + // Debugging flags. + flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") + flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") + flag.Bool("log-packets", false, "enable network packet logging.") + flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") + flag.Bool("alsologtostderr", false, "send log messages to stderr.") + + // Debugging flags: strace related + flag.Bool("strace", false, "enable strace.") + flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.") + flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.") + + // Flags that control sandbox runtime behavior. + flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") + flag.Var(watchdogActionPtr(watchdog.LogWarning), "watchdog-action", "sets what action the watchdog takes when triggered: log (default), panic.") + flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") + flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") + flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") + flag.Var(leakModePtr(refs.NoLeakChecking), "ref-leak-mode", "sets reference leak check mode: disabled (default), log-names, log-traces.") + flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") + + // Flags that control sandbox runtime behavior: FS related. + flag.Var(fileAccessTypePtr(FileAccessExclusive), "file-access", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") + flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") + flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") + flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") + flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") + flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") + + // Flags that control sandbox runtime behavior: network related. + flag.Var(networkTypePtr(NetworkSandbox), "network", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") + flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") + flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") + flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.") + flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.") + flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.") + flag.Var(queueingDisciplinePtr(QDiscFIFO), "qdisc", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.") + flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") + + // Test flags, not to be used outside tests, ever. + flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") + flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") + }) +} + +// NewFromFlags creates a new Config with values coming from command line flags. +func NewFromFlags() (*Config, error) { + conf := &Config{} + + obj := reflect.ValueOf(conf).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + name, ok := f.Tag.Lookup("flag") + if !ok { + // No flag set for this field. + continue + } + fl := flag.CommandLine.Lookup(name) + if fl == nil { + panic(fmt.Sprintf("Flag %q not found", name)) + } + x := reflect.ValueOf(flag.Get(fl.Value)) + obj.Field(i).Set(x) + } + + if len(conf.RootDir) == 0 { + // If not set, set default root dir to something (hopefully) user-writeable. + conf.RootDir = "/var/run/runsc" + if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { + conf.RootDir = filepath.Join(runtimeDir, "runsc") + } + } + + if err := conf.validate(); err != nil { + return nil, err + } + return conf, nil +} + +// ToFlags returns a slice of flags that correspond to the given Config. +func (c *Config) ToFlags() []string { + var rv []string + + obj := reflect.ValueOf(c).Elem() + st := obj.Type() + for i := 0; i < st.NumField(); i++ { + f := st.Field(i) + name, ok := f.Tag.Lookup("flag") + if !ok { + // No flag set for this field. + continue + } + val := getVal(obj.Field(i)) + + flag := flag.CommandLine.Lookup(name) + if flag == nil { + panic(fmt.Sprintf("Flag %q not found", name)) + } + if val == flag.DefValue { + continue + } + rv = append(rv, fmt.Sprintf("--%s=%s", flag.Name, val)) + } + return rv +} + +func getVal(field reflect.Value) string { + if str, ok := field.Addr().Interface().(fmt.Stringer); ok { + return str.String() + } + switch field.Kind() { + case reflect.Bool: + return strconv.FormatBool(field.Bool()) + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return strconv.FormatInt(field.Int(), 10) + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return strconv.FormatUint(field.Uint(), 10) + case reflect.String: + return field.String() + default: + panic("unknown type " + field.Kind().String()) + } +} diff --git a/runsc/flag/flag.go b/runsc/flag/flag.go index 0ca4829d7..ba1ff833f 100644 --- a/runsc/flag/flag.go +++ b/runsc/flag/flag.go @@ -21,13 +21,19 @@ import ( type FlagSet = flag.FlagSet var ( - NewFlagSet = flag.NewFlagSet - String = flag.String Bool = flag.Bool - Int = flag.Int - Uint = flag.Uint CommandLine = flag.CommandLine + Int = flag.Int + NewFlagSet = flag.NewFlagSet Parse = flag.Parse + String = flag.String + Uint = flag.Uint + Var = flag.Var ) const ContinueOnError = flag.ContinueOnError + +// Get returns the flag's underlying object. +func Get(v flag.Value) interface{} { + return v.(flag.Getter).Get() +} diff --git a/runsc/main.go b/runsc/main.go index c2ffecbdc..ed244c4ba 100644 --- a/runsc/main.go +++ b/runsc/main.go @@ -23,8 +23,6 @@ import ( "io/ioutil" "os" "os/signal" - "path/filepath" - "strings" "syscall" "time" @@ -41,58 +39,17 @@ import ( var ( // Although these flags are not part of the OCI spec, they are used by // Docker, and thus should not be changed. - rootDir = flag.String("root", "", "root directory for storage of container state.") - logFilename = flag.String("log", "", "file path where internal debug information is written, default is stdout.") - logFormat = flag.String("log-format", "text", "log format: text (default), json, or json-k8s.") - debug = flag.Bool("debug", false, "enable debug logging.") - showVersion = flag.Bool("version", false, "show version and exit.") // TODO(gvisor.dev/issue/193): support systemd cgroups systemdCgroup = flag.Bool("systemd-cgroup", false, "Use systemd for cgroups. NOT SUPPORTED.") + showVersion = flag.Bool("version", false, "show version and exit.") // These flags are unique to runsc, and are used to configure parts of the // system that are not covered by the runtime spec. // Debugging flags. - debugLog = flag.String("debug-log", "", "additional location for logs. If it ends with '/', log files are created inside the directory with default names. The following variables are available: %TIMESTAMP%, %COMMAND%.") - panicLog = flag.String("panic-log", "", "file path were panic reports and other Go's runtime messages are written.") - logPackets = flag.Bool("log-packets", false, "enable network packet logging.") - logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") - debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") - panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") - debugLogFormat = flag.String("debug-log-format", "text", "log format: text (default), json, or json-k8s.") - alsoLogToStderr = flag.Bool("alsologtostderr", false, "send log messages to stderr.") - - // Debugging flags: strace related - strace = flag.Bool("strace", false, "enable strace.") - straceSyscalls = flag.String("strace-syscalls", "", "comma-separated list of syscalls to trace. If --strace is true and this list is empty, then all syscalls will be traced.") - straceLogSize = flag.Uint("strace-log-size", 1024, "default size (in bytes) to log data argument blobs.") - - // Flags that control sandbox runtime behavior. - platformName = flag.String("platform", "ptrace", "specifies which platform to use: ptrace (default), kvm.") - network = flag.String("network", "sandbox", "specifies which network to use: sandbox (default), host, none. Using network inside the sandbox is more secure because it's isolated from the host network.") - hardwareGSO = flag.Bool("gso", true, "enable hardware segmentation offload if it is supported by a network device.") - softwareGSO = flag.Bool("software-gso", true, "enable software segmentation offload when hardware offload can't be enabled.") - txChecksumOffload = flag.Bool("tx-checksum-offload", false, "enable TX checksum offload.") - rxChecksumOffload = flag.Bool("rx-checksum-offload", true, "enable RX checksum offload.") - qDisc = flag.String("qdisc", "fifo", "specifies which queueing discipline to apply by default to the non loopback nics used by the sandbox.") - fileAccess = flag.String("file-access", "exclusive", "specifies which filesystem to use for the root mount: exclusive (default), shared. Volume mounts are always shared.") - fsGoferHostUDS = flag.Bool("fsgofer-host-uds", false, "allow the gofer to mount Unix Domain Sockets.") - overlay = flag.Bool("overlay", false, "wrap filesystem mounts with writable overlay. All modifications are stored in memory inside the sandbox.") - overlayfsStaleRead = flag.Bool("overlayfs-stale-read", true, "assume root mount is an overlay filesystem") - watchdogAction = flag.String("watchdog-action", "log", "sets what action the watchdog takes when triggered: log (default), panic.") - panicSignal = flag.Int("panic-signal", -1, "register signal handling that panics. Usually set to SIGUSR2(12) to troubleshoot hangs. -1 disables it.") - profile = flag.Bool("profile", false, "prepares the sandbox to use Golang profiler. Note that enabling profiler loosens the seccomp protection added to the sandbox (DO NOT USE IN PRODUCTION).") - netRaw = flag.Bool("net-raw", false, "enable raw sockets. When false, raw sockets are disabled by removing CAP_NET_RAW from containers (`runsc exec` will still be able to utilize raw sockets). Raw sockets allow malicious containers to craft packets and potentially attack the network.") - numNetworkChannels = flag.Int("num-network-channels", 1, "number of underlying channels(FDs) to use for network link endpoints.") - rootless = flag.Bool("rootless", false, "it allows the sandbox to be started with a user that is not root. Sandbox and Gofer processes may run with same privileges as current user.") - referenceLeakMode = flag.String("ref-leak-mode", "disabled", "sets reference leak check mode: disabled (default), log-names, log-traces.") - cpuNumFromQuota = flag.Bool("cpu-num-from-quota", false, "set cpu number to cpu quota (least integer greater or equal to quota value, but not less than 2)") - vfs2Enabled = flag.Bool("vfs2", false, "TEST ONLY; use while VFSv2 is landing. This uses the new experimental VFS layer.") - fuseEnabled = flag.Bool("fuse", false, "TEST ONLY; use while FUSE in VFSv2 is landing. This allows the use of the new experimental FUSE filesystem.") - - // Test flags, not to be used outside tests, ever. - testOnlyAllowRunAsCurrentUserWithoutChroot = flag.Bool("TESTONLY-unsafe-nonroot", false, "TEST ONLY; do not ever use! This skips many security measures that isolate the host from the sandbox.") - testOnlyTestNameEnv = flag.String("TESTONLY-test-name-env", "", "TEST ONLY; do not ever use! Used for automated tests to improve logging.") + logFD = flag.Int("log-fd", -1, "file descriptor to log to. If set, the 'log' flag is ignored.") + debugLogFD = flag.Int("debug-log-fd", -1, "file descriptor to write debug logs to. If set, the 'debug-log-dir' flag is ignored.") + panicLogFD = flag.Int("panic-log-fd", -1, "file descriptor to write Go's runtime messages.") ) func main() { @@ -136,6 +93,8 @@ func main() { subcommands.Register(new(cmd.Gofer), internalGroup) subcommands.Register(new(cmd.Statefile), internalGroup) + config.RegisterFlags() + // All subcommands must be registered before flag parsing. flag.Parse() @@ -147,6 +106,12 @@ func main() { os.Exit(0) } + // Create a new Config from the flags. + conf, err := config.NewFromFlags() + if err != nil { + cmd.Fatalf(err.Error()) + } + // TODO(gvisor.dev/issue/193): support systemd cgroups if *systemdCgroup { fmt.Fprintln(os.Stderr, "systemd cgroup flag passed, but systemd cgroups not supported. See gvisor.dev/issue/193") @@ -157,103 +122,28 @@ func main() { if *logFD > -1 { errorLogger = os.NewFile(uintptr(*logFD), "error log file") - } else if *logFilename != "" { + } else if conf.LogFilename != "" { // We must set O_APPEND and not O_TRUNC because Docker passes // the same log file for all commands (and also parses these // log files), so we can't destroy them on each command. var err error - errorLogger, err = os.OpenFile(*logFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) + errorLogger, err = os.OpenFile(conf.LogFilename, os.O_WRONLY|os.O_CREATE|os.O_APPEND, 0644) if err != nil { - cmd.Fatalf("error opening log file %q: %v", *logFilename, err) + cmd.Fatalf("error opening log file %q: %v", conf.LogFilename, err) } } cmd.ErrorLogger = errorLogger - platformType := *platformName - if _, err := platform.Lookup(platformType); err != nil { - cmd.Fatalf("%v", err) - } - - fsAccess, err := config.MakeFileAccessType(*fileAccess) - if err != nil { - cmd.Fatalf("%v", err) - } - - if fsAccess == config.FileAccessShared && *overlay { - cmd.Fatalf("overlay flag is incompatible with shared file access") - } - - netType, err := config.MakeNetworkType(*network) - if err != nil { + if _, err := platform.Lookup(conf.Platform); err != nil { cmd.Fatalf("%v", err) } - wa, err := config.MakeWatchdogAction(*watchdogAction) - if err != nil { - cmd.Fatalf("%v", err) - } - - if *numNetworkChannels <= 0 { - cmd.Fatalf("num_network_channels must be > 0, got: %d", *numNetworkChannels) - } - - refsLeakMode, err := config.MakeRefsLeakMode(*referenceLeakMode) - if err != nil { - cmd.Fatalf("%v", err) - } - - queueingDiscipline, err := config.MakeQueueingDiscipline(*qDisc) - if err != nil { - cmd.Fatalf("%s", err) - } - // Sets the reference leak check mode. Also set it in config below to // propagate it to child processes. - refs.SetLeakMode(refsLeakMode) - - // Create a new Config from the flags. - conf := &config.Config{ - RootDir: *rootDir, - Debug: *debug, - LogFilename: *logFilename, - LogFormat: *logFormat, - DebugLog: *debugLog, - PanicLog: *panicLog, - DebugLogFormat: *debugLogFormat, - FileAccess: fsAccess, - FSGoferHostUDS: *fsGoferHostUDS, - Overlay: *overlay, - Network: netType, - HardwareGSO: *hardwareGSO, - SoftwareGSO: *softwareGSO, - TXChecksumOffload: *txChecksumOffload, - RXChecksumOffload: *rxChecksumOffload, - LogPackets: *logPackets, - Platform: platformType, - Strace: *strace, - StraceLogSize: *straceLogSize, - WatchdogAction: wa, - PanicSignal: *panicSignal, - ProfileEnable: *profile, - EnableRaw: *netRaw, - NumNetworkChannels: *numNetworkChannels, - Rootless: *rootless, - AlsoLogToStderr: *alsoLogToStderr, - ReferenceLeakMode: refsLeakMode, - OverlayfsStaleRead: *overlayfsStaleRead, - CPUNumFromQuota: *cpuNumFromQuota, - VFS2: *vfs2Enabled, - FUSE: *fuseEnabled, - QDisc: queueingDiscipline, - TestOnlyAllowRunAsCurrentUserWithoutChroot: *testOnlyAllowRunAsCurrentUserWithoutChroot, - TestOnlyTestNameEnv: *testOnlyTestNameEnv, - } - if len(*straceSyscalls) != 0 { - conf.StraceSyscalls = strings.Split(*straceSyscalls, ",") - } + refs.SetLeakMode(conf.ReferenceLeak) // Set up logging. - if *debug { + if conf.Debug { log.SetLevel(log.Debug) } @@ -275,14 +165,14 @@ func main() { if *debugLogFD > -1 { f := os.NewFile(uintptr(*debugLogFD), "debug log file") - e = newEmitter(*debugLogFormat, f) + e = newEmitter(conf.DebugLogFormat, f) - } else if *debugLog != "" { - f, err := specutils.DebugLogFile(*debugLog, subcommand, "" /* name */) + } else if conf.DebugLog != "" { + f, err := specutils.DebugLogFile(conf.DebugLog, subcommand, "" /* name */) if err != nil { - cmd.Fatalf("error opening debug log file in %q: %v", *debugLog, err) + cmd.Fatalf("error opening debug log file in %q: %v", conf.DebugLog, err) } - e = newEmitter(*debugLogFormat, f) + e = newEmitter(conf.DebugLogFormat, f) } else { // Stderr is reserved for the application, just discard the logs if no debug @@ -308,8 +198,8 @@ func main() { if err := syscall.Dup3(fd, int(os.Stderr.Fd()), 0); err != nil { cmd.Fatalf("error dup'ing fd %d to stderr: %v", fd, err) } - } else if *alsoLogToStderr { - e = &log.MultiEmitter{e, newEmitter(*debugLogFormat, os.Stderr)} + } else if conf.AlsoLogToStderr { + e = &log.MultiEmitter{e, newEmitter(conf.DebugLogFormat, os.Stderr)} } log.SetTarget(e) @@ -328,7 +218,7 @@ func main() { log.Infof("\t\tVFS2 enabled: %v", conf.VFS2) log.Infof("***************************") - if *testOnlyAllowRunAsCurrentUserWithoutChroot { + if conf.TestOnlyAllowRunAsCurrentUserWithoutChroot { // SIGTERM is sent to all processes if a test exceeds its // timeout and this case is handled by syscall_test_runner. log.Warningf("Block the TERM signal. This is only safe in tests!") @@ -364,11 +254,3 @@ func newEmitter(format string, logFile io.Writer) log.Emitter { cmd.Fatalf("invalid log format %q, must be 'text', 'json', or 'json-k8s'", format) panic("unreachable") } - -func init() { - // Set default root dir to something (hopefully) user-writeable. - *rootDir = "/var/run/runsc" - if runtimeDir := os.Getenv("XDG_RUNTIME_DIR"); runtimeDir != "" { - *rootDir = filepath.Join(runtimeDir, "runsc") - } -} diff --git a/runsc/sandbox/network.go b/runsc/sandbox/network.go index f9abb2d44..0b9f39466 100644 --- a/runsc/sandbox/network.go +++ b/runsc/sandbox/network.go @@ -69,7 +69,7 @@ func setupNetwork(conn *urpc.Client, pid int, spec *specs.Spec, conf *config.Con case config.NetworkHost: // Nothing to do here. default: - return fmt.Errorf("invalid network type: %d", conf.Network) + return fmt.Errorf("invalid network type: %v", conf.Network) } return nil } diff --git a/test/syscalls/linux/ip6tables.cc b/test/syscalls/linux/ip6tables.cc index 685e513f8..78e1fa09d 100644 --- a/test/syscalls/linux/ip6tables.cc +++ b/test/syscalls/linux/ip6tables.cc @@ -34,6 +34,54 @@ constexpr size_t kEmptyStandardEntrySize = constexpr size_t kEmptyErrorEntrySize = sizeof(struct ip6t_entry) + sizeof(struct xt_error_target); +TEST(IP6TablesBasic, FailSockoptNonRaw) { + // Even if the user has CAP_NET_RAW, they shouldn't be able to use the + // ip6tables sockopts with a non-raw socket. + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info); + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(ENOPROTOOPT)); + + EXPECT_THAT(close(sock), SyscallSucceeds()); +} + +TEST(IP6TablesBasic, GetInfoErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ipt_getinfo info = {}; + snprintf(info.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + socklen_t info_size = sizeof(info) - 1; + EXPECT_THAT(getsockopt(sock, SOL_IPV6, IP6T_SO_GET_INFO, &info, &info_size), + SyscallFailsWithErrno(EINVAL)); +} + +TEST(IP6TablesBasic, GetEntriesErrorPrecedence) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_RAW))); + + int sock; + ASSERT_THAT(sock = socket(AF_INET6, SOCK_DGRAM, 0), SyscallSucceeds()); + + // When using the wrong type of socket and a too-short optlen, we should get + // EINVAL. + struct ip6t_get_entries entries = {}; + socklen_t entries_size = sizeof(struct ip6t_get_entries) - 1; + snprintf(entries.name, XT_TABLE_MAXNAMELEN, "%s", kNatTablename); + EXPECT_THAT( + getsockopt(sock, SOL_IPV6, IP6T_SO_GET_ENTRIES, &entries, &entries_size), + SyscallFailsWithErrno(EINVAL)); +} + // This tests the initial state of a machine with empty ip6tables via // getsockopt(IP6T_SO_GET_INFO). We don't have a guarantee that the iptables are // empty when running in native, but we can test that gVisor has the same diff --git a/test/syscalls/linux/pty.cc b/test/syscalls/linux/pty.cc index f9392b9e0..2e4ab6ca8 100644 --- a/test/syscalls/linux/pty.cc +++ b/test/syscalls/linux/pty.cc @@ -51,6 +51,7 @@ using ::testing::AnyOf; using ::testing::Contains; using ::testing::Eq; using ::testing::Not; +using SubprocessCallback = std::function<void()>; // Tests Unix98 pseudoterminals. // @@ -70,7 +71,7 @@ constexpr absl::Duration kTimeout = absl::Seconds(20); // The maximum line size in bytes returned per read from a pty file. constexpr int kMaxLineSize = 4096; -constexpr char kMasterPath[] = "/dev/ptmx"; +constexpr char kMainPath[] = "/dev/ptmx"; // glibc defines its own, different, version of struct termios. We care about // what the kernel does, not glibc. @@ -387,22 +388,22 @@ PosixErrorOr<size_t> PollAndReadFd(int fd, void* buf, size_t count, TEST(PtyTrunc, Truncate) { // Opening PTYs with O_TRUNC shouldn't cause an error, but calls to // (f)truncate should. - FileDescriptor master = - ASSERT_NO_ERRNO_AND_VALUE(Open(kMasterPath, O_RDWR | O_TRUNC)); - int n = ASSERT_NO_ERRNO_AND_VALUE(SlaveID(master)); + FileDescriptor main = + ASSERT_NO_ERRNO_AND_VALUE(Open(kMainPath, O_RDWR | O_TRUNC)); + int n = ASSERT_NO_ERRNO_AND_VALUE(ReplicaID(main)); std::string spath = absl::StrCat("/dev/pts/", n); - FileDescriptor slave = + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(Open(spath, O_RDWR | O_NONBLOCK | O_TRUNC)); - EXPECT_THAT(truncate(kMasterPath, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(truncate(kMainPath, 0), SyscallFailsWithErrno(EINVAL)); EXPECT_THAT(truncate(spath.c_str(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(ftruncate(master.get(), 0), SyscallFailsWithErrno(EINVAL)); - EXPECT_THAT(ftruncate(slave.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(main.get(), 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(ftruncate(replica.get(), 0), SyscallFailsWithErrno(EINVAL)); } -TEST(BasicPtyTest, StatUnopenedMaster) { +TEST(BasicPtyTest, StatUnopenedMain) { struct stat s; - ASSERT_THAT(stat(kMasterPath, &s), SyscallSucceeds()); + ASSERT_THAT(stat(kMainPath, &s), SyscallSucceeds()); EXPECT_EQ(s.st_rdev, makedev(TTYAUX_MAJOR, kPtmxMinor)); EXPECT_EQ(s.st_size, 0); @@ -453,41 +454,41 @@ void ExpectReadable(const FileDescriptor& fd, int expected, char* buf) { EXPECT_EQ(expected, n); } -TEST(BasicPtyTest, OpenMasterSlave) { - FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); +TEST(BasicPtyTest, OpenMainReplica) { + FileDescriptor main = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main)); } -// The slave entry in /dev/pts/ disappears when the master is closed, even if -// the slave is still open. -TEST(BasicPtyTest, SlaveEntryGoneAfterMasterClose) { - FileDescriptor master = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); +// The replica entry in /dev/pts/ disappears when the main is closed, even if +// the replica is still open. +TEST(BasicPtyTest, ReplicaEntryGoneAfterMainClose) { + FileDescriptor main = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main)); // Get pty index. int index = -1; - ASSERT_THAT(ioctl(master.get(), TIOCGPTN, &index), SyscallSucceeds()); + ASSERT_THAT(ioctl(main.get(), TIOCGPTN, &index), SyscallSucceeds()); std::string path = absl::StrCat("/dev/pts/", index); struct stat st; EXPECT_THAT(stat(path.c_str(), &st), SyscallSucceeds()); - master.reset(); + main.reset(); EXPECT_THAT(stat(path.c_str(), &st), SyscallFailsWithErrno(ENOENT)); } TEST(BasicPtyTest, Getdents) { - FileDescriptor master1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); + FileDescriptor main1 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index1 = -1; - ASSERT_THAT(ioctl(master1.get(), TIOCGPTN, &index1), SyscallSucceeds()); - FileDescriptor slave1 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master1)); + ASSERT_THAT(ioctl(main1.get(), TIOCGPTN, &index1), SyscallSucceeds()); + FileDescriptor replica1 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main1)); - FileDescriptor master2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); + FileDescriptor main2 = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR)); int index2 = -1; - ASSERT_THAT(ioctl(master2.get(), TIOCGPTN, &index2), SyscallSucceeds()); - FileDescriptor slave2 = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master2)); + ASSERT_THAT(ioctl(main2.get(), TIOCGPTN, &index2), SyscallSucceeds()); + FileDescriptor replica2 = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main2)); // The directory contains ptmx, index1, and index2. (Plus any additional PTYs // unrelated to this test.) @@ -497,9 +498,9 @@ TEST(BasicPtyTest, Getdents) { EXPECT_THAT(contents, Contains(absl::StrCat(index1))); EXPECT_THAT(contents, Contains(absl::StrCat(index2))); - master2.reset(); + main2.reset(); - // The directory contains ptmx and index1, but not index2 since the master is + // The directory contains ptmx and index1, but not index2 since the main is // closed. (Plus any additional PTYs unrelated to this test.) contents = ASSERT_NO_ERRNO_AND_VALUE(ListDir("/dev/pts/", true)); @@ -518,104 +519,105 @@ TEST(BasicPtyTest, Getdents) { class PtyTest : public ::testing::Test { protected: void SetUp() override { - master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + main_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main_)); } void DisableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag &= ~ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } void EnableCanonical() { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); t.c_lflag |= ICANON; - EXPECT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); } - // Master and slave ends of the PTY. Non-blocking. - FileDescriptor master_; - FileDescriptor slave_; + // Main and replica ends of the PTY. Non-blocking. + FileDescriptor main_; + FileDescriptor replica_; }; -// Master to slave sanity test. -TEST_F(PtyTest, WriteMasterToSlave) { - // N.B. by default, the slave reads nothing until the master writes a newline. +// Main to replica sanity test. +TEST_F(PtyTest, WriteMainToReplica) { + // N.B. by default, the replica reads nothing until the main writes a newline. constexpr char kBuf[] = "hello\n"; - EXPECT_THAT(WriteFd(master_.get(), kBuf, sizeof(kBuf) - 1), + EXPECT_THAT(WriteFd(main_.get(), kBuf, sizeof(kBuf) - 1), SyscallSucceedsWithValue(sizeof(kBuf) - 1)); - // Linux moves data from the master to the slave via async work scheduled via + // Linux moves data from the main to the replica via async work scheduled via // tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". char buf[sizeof(kBuf)] = {}; - ExpectReadable(slave_, sizeof(buf) - 1, buf); + ExpectReadable(replica_, sizeof(buf) - 1, buf); EXPECT_EQ(memcmp(buf, kBuf, sizeof(kBuf)), 0); } -// Slave to master sanity test. -TEST_F(PtyTest, WriteSlaveToMaster) { - // N.B. by default, the master reads nothing until the slave writes a newline, - // and the master gets a carriage return. +// Replica to main sanity test. +TEST_F(PtyTest, WriteReplicaToMain) { + // N.B. by default, the main reads nothing until the replica writes a newline, + // and the main gets a carriage return. constexpr char kInput[] = "hello\n"; constexpr char kExpected[] = "hello\r\n"; - EXPECT_THAT(WriteFd(slave_.get(), kInput, sizeof(kInput) - 1), + EXPECT_THAT(WriteFd(replica_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - // Linux moves data from the master to the slave via async work scheduled via + // Linux moves data from the main to the replica via async work scheduled via // tty_flip_buffer_push. Since it is asynchronous, the data may not be // available for reading immediately. Instead we must poll and assert that it // becomes available "soon". char buf[sizeof(kExpected)] = {}; - ExpectReadable(master_, sizeof(buf) - 1, buf); + ExpectReadable(main_, sizeof(buf) - 1, buf); EXPECT_EQ(memcmp(buf, kExpected, sizeof(kExpected)), 0); } TEST_F(PtyTest, WriteInvalidUTF8) { char c = 0xff; - ASSERT_THAT(syscall(__NR_write, master_.get(), &c, sizeof(c)), + ASSERT_THAT(syscall(__NR_write, main_.get(), &c, sizeof(c)), SyscallSucceedsWithValue(sizeof(c))); } -// Both the master and slave report the standard default termios settings. +// Both the main and replica report the standard default termios settings. // -// Note that TCGETS on the master actually redirects to the slave (see comment -// on MasterTermiosUnchangable). +// Note that TCGETS on the main actually redirects to the replica (see comment +// on MainTermiosUnchangable). TEST_F(PtyTest, DefaultTermios) { struct kernel_termios t = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); - EXPECT_THAT(ioctl(master_.get(), TCGETS, &t), SyscallSucceeds()); + EXPECT_THAT(ioctl(main_.get(), TCGETS, &t), SyscallSucceeds()); EXPECT_EQ(t, DefaultTermios()); } -// Changing termios from the master actually affects the slave. +// Changing termios from the main actually affects the replica. // -// TCSETS on the master actually redirects to the slave (see comment on -// MasterTermiosUnchangable). -TEST_F(PtyTest, TermiosAffectsSlave) { - struct kernel_termios master_termios = {}; - EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); - master_termios.c_lflag ^= ICANON; - EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); - - struct kernel_termios slave_termios = {}; - EXPECT_THAT(ioctl(slave_.get(), TCGETS, &slave_termios), SyscallSucceeds()); - EXPECT_EQ(master_termios, slave_termios); +// TCSETS on the main actually redirects to the replica (see comment on +// MainTermiosUnchangable). +TEST_F(PtyTest, TermiosAffectsReplica) { + struct kernel_termios main_termios = {}; + EXPECT_THAT(ioctl(main_.get(), TCGETS, &main_termios), SyscallSucceeds()); + main_termios.c_lflag ^= ICANON; + EXPECT_THAT(ioctl(main_.get(), TCSETS, &main_termios), SyscallSucceeds()); + + struct kernel_termios replica_termios = {}; + EXPECT_THAT(ioctl(replica_.get(), TCGETS, &replica_termios), + SyscallSucceeds()); + EXPECT_EQ(main_termios, replica_termios); } -// The master end of the pty has termios: +// The main end of the pty has termios: // // struct kernel_termios t = { // .c_iflag = 0; @@ -627,25 +629,25 @@ TEST_F(PtyTest, TermiosAffectsSlave) { // // (From drivers/tty/pty.c:unix98_pty_init) // -// All termios control ioctls on the master actually redirect to the slave +// All termios control ioctls on the main actually redirect to the replica // (drivers/tty/tty_ioctl.c:tty_mode_ioctl), making it impossible to change the -// master termios. +// main termios. // // Verify this by setting ICRNL (which rewrites input \r to \n) and verify that -// it has no effect on the master. -TEST_F(PtyTest, MasterTermiosUnchangable) { - struct kernel_termios master_termios = {}; - EXPECT_THAT(ioctl(master_.get(), TCGETS, &master_termios), SyscallSucceeds()); - master_termios.c_lflag |= ICRNL; - EXPECT_THAT(ioctl(master_.get(), TCSETS, &master_termios), SyscallSucceeds()); +// it has no effect on the main. +TEST_F(PtyTest, MainTermiosUnchangable) { + struct kernel_termios main_termios = {}; + EXPECT_THAT(ioctl(main_.get(), TCGETS, &main_termios), SyscallSucceeds()); + main_termios.c_lflag |= ICRNL; + EXPECT_THAT(ioctl(main_.get(), TCSETS, &main_termios), SyscallSucceeds()); char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(master_, 1, &c); + ExpectReadable(main_, 1, &c); EXPECT_EQ(c, '\r'); // ICRNL had no effect! - ExpectFinished(master_); + ExpectFinished(main_); } // ICRNL rewrites input \r to \n. @@ -653,15 +655,15 @@ TEST_F(PtyTest, TermiosICRNL) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= ICRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; - ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(main_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\n'); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ONLCR rewrites output \n to \r\n. @@ -669,42 +671,42 @@ TEST_F(PtyTest, TermiosONLCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Extra byte for NUL for EXPECT_STREQ. char buf[3] = {}; - ExpectReadable(master_, 2, buf); + ExpectReadable(main_, 2, buf); EXPECT_STREQ(buf, "\r\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosIGNCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\r'; - ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(main_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. - ASSERT_THAT(PollAndReadFd(slave_.get(), &c, 1, kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), &c, 1, kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); } -// Test that we can successfully poll for readable data from the slave. -TEST_F(PtyTest, TermiosPollSlave) { +// Test that we can successfully poll for readable data from the replica. +TEST_F(PtyTest, TermiosPollReplica) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); absl::Notification notify; - int sfd = slave_.get(); + int sfd = replica_.get(); ScopedThread th([sfd, ¬ify]() { notify.Notify(); @@ -723,18 +725,18 @@ TEST_F(PtyTest, TermiosPollSlave) { absl::SleepFor(absl::Seconds(1)); char s[] = "foo\n"; - ASSERT_THAT(WriteFd(master_.get(), s, strlen(s) + 1), SyscallSucceeds()); + ASSERT_THAT(WriteFd(main_.get(), s, strlen(s) + 1), SyscallSucceeds()); } -// Test that we can successfully poll for readable data from the master. -TEST_F(PtyTest, TermiosPollMaster) { +// Test that we can successfully poll for readable data from the main. +TEST_F(PtyTest, TermiosPollMain) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= IGNCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(master_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(main_.get(), TCSETS, &t), SyscallSucceeds()); absl::Notification notify; - int mfd = master_.get(); + int mfd = main_.get(); ScopedThread th([mfd, ¬ify]() { notify.Notify(); @@ -753,57 +755,57 @@ TEST_F(PtyTest, TermiosPollMaster) { absl::SleepFor(absl::Seconds(1)); char s[] = "foo\n"; - ASSERT_THAT(WriteFd(slave_.get(), s, strlen(s) + 1), SyscallSucceeds()); + ASSERT_THAT(WriteFd(replica_.get(), s, strlen(s) + 1), SyscallSucceeds()); } TEST_F(PtyTest, TermiosINLCR) { struct kernel_termios t = DefaultTermios(); t.c_iflag |= INLCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); char c = '\n'; - ASSERT_THAT(WriteFd(master_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(main_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); EXPECT_EQ(c, '\r'); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, TermiosONOCR) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= ONOCR; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. - ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), + ASSERT_THAT(PollAndReadFd(main_.get(), &c, 1, kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); // This time the column is greater than 0, so we should be able to read the CR // out of the other end. constexpr char kInput[] = "foo\r"; constexpr int kInputSize = sizeof(kInput) - 1; - ASSERT_THAT(WriteFd(slave_.get(), kInput, kInputSize), + ASSERT_THAT(WriteFd(replica_.get(), kInput, kInputSize), SyscallSucceedsWithValue(kInputSize)); char buf[kInputSize] = {}; - ExpectReadable(master_, kInputSize, buf); + ExpectReadable(main_, kInputSize, buf); EXPECT_EQ(memcmp(buf, kInput, kInputSize), 0); - ExpectFinished(master_); + ExpectFinished(main_); // Terminal should be at column 0 again, so no CR can be read. - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); // Nothing to read. - ASSERT_THAT(PollAndReadFd(master_.get(), &c, 1, kTimeout), + ASSERT_THAT(PollAndReadFd(main_.get(), &c, 1, kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); } @@ -811,16 +813,16 @@ TEST_F(PtyTest, TermiosOCRNL) { struct kernel_termios t = DefaultTermios(); t.c_oflag |= OCRNL; t.c_lflag &= ~ICANON; // for byte-by-byte reading. - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); // The terminal is at column 0, so there should be no CR to read. char c = '\r'; - ASSERT_THAT(WriteFd(slave_.get(), &c, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(replica_.get(), &c, 1), SyscallSucceedsWithValue(1)); - ExpectReadable(master_, 1, &c); + ExpectReadable(main_, 1, &c); EXPECT_EQ(c, '\n'); - ExpectFinished(master_); + ExpectFinished(main_); } // Tests that VEOL is disabled when we start, and that we can set it to enable @@ -828,27 +830,27 @@ TEST_F(PtyTest, TermiosOCRNL) { TEST_F(PtyTest, VEOLTermination) { // Write a few bytes ending with '\0', and confirm that we can't read. constexpr char kInput[] = "hello"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), + ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); // Set the EOL character to '=' and write it. constexpr char delim = '='; struct kernel_termios t = DefaultTermios(); t.c_cc[VEOL] = delim; - ASSERT_THAT(ioctl(slave_.get(), TCSETS, &t), SyscallSucceeds()); - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(ioctl(replica_.get(), TCSETS, &t), SyscallSucceeds()); + ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now we can read, as sending EOL caused the line to become available. - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_EQ(memcmp(buf, kInput, sizeof(kInput)), 0); - ExpectReadable(slave_, 1, buf); + ExpectReadable(replica_, 1, buf); EXPECT_EQ(buf[0], '='); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write more than the 4096 character limit, then a @@ -859,14 +861,14 @@ TEST_F(PtyTest, CanonBigWrite) { char input[kWriteLen]; memset(input, 'M', kWriteLen - 1); input[kWriteLen - 1] = '\n'; - ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), + ASSERT_THAT(WriteFd(main_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that data written in canonical mode can be read immediately once @@ -875,36 +877,36 @@ TEST_F(PtyTest, SwitchCanonToNoncanon) { // Write a few bytes without a terminating character, switch to noncanonical // mode, and read them. constexpr char kInput[] = "hello"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), + ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); // Nothing available yet. char buf[sizeof(kInput)] = {}; - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(kInput), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(kInput), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); DisableCanonical(); - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchCanonToNonCanonNewline) { // Write a few bytes with a terminating character. constexpr char kInput[] = "hello\n"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput)), + ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); DisableCanonical(); // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput), buf); + ExpectReadable(replica_, sizeof(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { @@ -914,23 +916,23 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNewlineBig) { constexpr int kWriteLen = 4100; char input[kWriteLen]; memset(input, 'M', kWriteLen); - ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), + ASSERT_THAT(WriteFd(main_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); // Wait for the input queue to fill. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); constexpr char delim = '\n'; - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1)); EnableCanonical(); // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); // We can also read the remaining characters. - ExpectReadable(slave_, 6, buf); + ExpectReadable(replica_, 6, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { @@ -939,18 +941,18 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewline) { // Write a few bytes without a terminating character. // mode, and read them. constexpr char kInput[] = "hello"; - ASSERT_THAT(WriteFd(master_.get(), kInput, sizeof(kInput) - 1), + ASSERT_THAT(WriteFd(main_.get(), kInput, sizeof(kInput) - 1), SyscallSucceedsWithValue(sizeof(kInput) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput) - 1)); EnableCanonical(); // We can read the line. char buf[sizeof(kInput)] = {}; - ExpectReadable(slave_, sizeof(kInput) - 1, buf); + ExpectReadable(replica_, sizeof(kInput) - 1, buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { @@ -961,17 +963,17 @@ TEST_F(PtyTest, SwitchNoncanonToCanonNoNewlineBig) { constexpr int kWriteLen = 4100; char input[kWriteLen]; memset(input, 'M', kWriteLen); - ASSERT_THAT(WriteFd(master_.get(), input, kWriteLen), + ASSERT_THAT(WriteFd(main_.get(), input, kWriteLen), SyscallSucceedsWithValue(kWriteLen)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); EnableCanonical(); // We can read the line. char buf[kMaxLineSize] = {}; - ExpectReadable(slave_, kMaxLineSize - 1, buf); + ExpectReadable(replica_, kMaxLineSize - 1, buf); - ExpectFinished(slave_); + ExpectFinished(replica_); } // Tests that we can write over the 4095 noncanonical limit, then read out @@ -985,22 +987,22 @@ TEST_F(PtyTest, NoncanonBigWrite) { for (int i = 0; i < kInputSize; i++) { // This makes too many syscalls for save/restore. const DisableSave ds; - ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)), + ASSERT_THAT(WriteFd(main_.get(), &kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); } // We should be able to read out everything. Sleep a bit so that Linux has a - // chance to move data from the master to the slave. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kMaxLineSize - 1)); + // chance to move data from the main to the replica. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kMaxLineSize - 1)); for (int i = 0; i < kInputSize; i++) { // This makes too many syscalls for save/restore. const DisableSave ds; char c; - ExpectReadable(slave_, 1, &c); + ExpectReadable(replica_, 1, &c); ASSERT_EQ(c, kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1008,25 +1010,25 @@ TEST_F(PtyTest, NoncanonBigWrite) { // Test newline. TEST_F(PtyTest, TermiosICANONNewline) { char input[3] = {'a', 'b', 'c'}; - ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)), + ASSERT_THAT(WriteFd(main_.get(), input, sizeof(input)), SyscallSucceedsWithValue(sizeof(input))); // Extra bytes for newline (written later) and NUL for EXPECT_STREQ. char buf[5] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = '\n'; - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(input) + 1)); - ExpectReadable(slave_, sizeof(input) + 1, buf); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(input) + 1)); + ExpectReadable(replica_, sizeof(input) + 1, buf); EXPECT_STREQ(buf, "abc\n"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON doesn't make input available until a line delimiter is typed. @@ -1034,23 +1036,23 @@ TEST_F(PtyTest, TermiosICANONNewline) { // Test EOF (^D). TEST_F(PtyTest, TermiosICANONEOF) { char input[3] = {'a', 'b', 'c'}; - ASSERT_THAT(WriteFd(master_.get(), input, sizeof(input)), + ASSERT_THAT(WriteFd(main_.get(), input, sizeof(input)), SyscallSucceedsWithValue(sizeof(input))); // Extra byte for NUL for EXPECT_STREQ. char buf[4] = {}; // Nothing available yet. - ASSERT_THAT(PollAndReadFd(slave_.get(), buf, sizeof(input), kTimeout), + ASSERT_THAT(PollAndReadFd(replica_.get(), buf, sizeof(input), kTimeout), PosixErrorIs(ETIMEDOUT, ::testing::StrEq("Poll timed out"))); char delim = ControlCharacter('D'); - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1)); // Now it is available. Note that ^D is not included. - ExpectReadable(slave_, sizeof(input), buf); + ExpectReadable(replica_, sizeof(input), buf); EXPECT_STREQ(buf, "abc"); - ExpectFinished(slave_); + ExpectFinished(replica_); } // ICANON limits us to 4096 bytes including a terminating character. Anything @@ -1067,21 +1069,21 @@ TEST_F(PtyTest, CanonDiscard) { // This makes too many syscalls for save/restore. const DisableSave ds; for (int i = 0; i < kInputSize; i++) { - ASSERT_THAT(WriteFd(master_.get(), &kInput, sizeof(kInput)), + ASSERT_THAT(WriteFd(main_.get(), &kInput, sizeof(kInput)), SyscallSucceedsWithValue(sizeof(kInput))); } - ASSERT_THAT(WriteFd(master_.get(), &delim, 1), SyscallSucceedsWithValue(1)); + ASSERT_THAT(WriteFd(main_.get(), &delim, 1), SyscallSucceedsWithValue(1)); } // There should be multiple truncated lines available to read. for (int i = 0; i < kIter; i++) { char buf[kInputSize] = {}; - ExpectReadable(slave_, kMaxLineSize, buf); + ExpectReadable(replica_, kMaxLineSize, buf); EXPECT_EQ(buf[kMaxLineSize - 1], delim); EXPECT_EQ(buf[kMaxLineSize - 2], kInput); } - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, CanonMultiline) { @@ -1089,22 +1091,22 @@ TEST_F(PtyTest, CanonMultiline) { constexpr char kInput2[] = "BLUE\n"; // Write both lines. - ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), + ASSERT_THAT(WriteFd(main_.get(), kInput1, sizeof(kInput1) - 1), SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1), + ASSERT_THAT(WriteFd(main_.get(), kInput2, sizeof(kInput2) - 1), SyscallSucceedsWithValue(sizeof(kInput2) - 1)); // Get the first line. char line1[8] = {}; - ExpectReadable(slave_, sizeof(kInput1) - 1, line1); + ExpectReadable(replica_, sizeof(kInput1) - 1, line1); EXPECT_STREQ(line1, kInput1); // Get the second line. char line2[8] = {}; - ExpectReadable(slave_, sizeof(kInput2) - 1, line2); + ExpectReadable(replica_, sizeof(kInput2) - 1, line2); EXPECT_STREQ(line2, kInput2); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { @@ -1115,21 +1117,21 @@ TEST_F(PtyTest, SwitchNoncanonToCanonMultiline) { constexpr char kExpected[] = "GO\nBLUE\n"; // Write both lines. - ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), + ASSERT_THAT(WriteFd(main_.get(), kInput1, sizeof(kInput1) - 1), SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_THAT(WriteFd(master_.get(), kInput2, sizeof(kInput2) - 1), + ASSERT_THAT(WriteFd(main_.get(), kInput2, sizeof(kInput2) - 1), SyscallSucceedsWithValue(sizeof(kInput2) - 1)); ASSERT_NO_ERRNO( - WaitUntilReceived(slave_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) + sizeof(kInput2) - 2)); EnableCanonical(); // Get all together as one line. char line[9] = {}; - ExpectReadable(slave_, 8, line); + ExpectReadable(replica_, 8, line); EXPECT_STREQ(line, kExpected); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, SwitchTwiceMultiline) { @@ -1138,7 +1140,7 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { // Write each line. for (const std::string& input : kInputs) { - ASSERT_THAT(WriteFd(master_.get(), input.c_str(), input.size()), + ASSERT_THAT(WriteFd(main_.get(), input.c_str(), input.size()), SyscallSucceedsWithValue(input.size())); } @@ -1146,32 +1148,32 @@ TEST_F(PtyTest, SwitchTwiceMultiline) { // All written characters have to make it into the input queue before // canonical mode is re-enabled. If the final '!' character hasn't been // enqueued before canonical mode is re-enabled, it won't be readable. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), kExpected.size())); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), kExpected.size())); EnableCanonical(); // Get all together as one line. char line[10] = {}; - ExpectReadable(slave_, 9, line); + ExpectReadable(replica_, 9, line); EXPECT_STREQ(line, kExpected.c_str()); - ExpectFinished(slave_); + ExpectFinished(replica_); } TEST_F(PtyTest, QueueSize) { // Write the line. constexpr char kInput1[] = "GO\n"; - ASSERT_THAT(WriteFd(master_.get(), kInput1, sizeof(kInput1) - 1), + ASSERT_THAT(WriteFd(main_.get(), kInput1, sizeof(kInput1) - 1), SyscallSucceedsWithValue(sizeof(kInput1) - 1)); - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); // Ensure that writing more (beyond what is readable) does not impact the // readable size. char input[kMaxLineSize]; memset(input, 'M', kMaxLineSize); - ASSERT_THAT(WriteFd(master_.get(), input, kMaxLineSize), + ASSERT_THAT(WriteFd(main_.get(), input, kMaxLineSize), SyscallSucceedsWithValue(kMaxLineSize)); int inputBufSize = ASSERT_NO_ERRNO_AND_VALUE( - WaitUntilReceived(slave_.get(), sizeof(kInput1) - 1)); + WaitUntilReceived(replica_.get(), sizeof(kInput1) - 1)); EXPECT_EQ(inputBufSize, sizeof(kInput1) - 1); } @@ -1190,15 +1192,14 @@ TEST_F(PtyTest, PartialBadBuffer) { // Leave only one free byte in the buffer. char* bad_buffer = buf + kPageSize - 1; - // Write to the master. + // Write to the main. constexpr char kBuf[] = "hello\n"; constexpr size_t size = sizeof(kBuf) - 1; - EXPECT_THAT(WriteFd(master_.get(), kBuf, size), - SyscallSucceedsWithValue(size)); + EXPECT_THAT(WriteFd(main_.get(), kBuf, size), SyscallSucceedsWithValue(size)); - // Read from the slave into bad_buffer. - ASSERT_NO_ERRNO(WaitUntilReceived(slave_.get(), size)); - EXPECT_THAT(ReadFd(slave_.get(), bad_buffer, size), + // Read from the replica into bad_buffer. + ASSERT_NO_ERRNO(WaitUntilReceived(replica_.get(), size)); + EXPECT_THAT(ReadFd(replica_.get(), bad_buffer, size), SyscallFailsWithErrno(EFAULT)); EXPECT_THAT(munmap(addr, 2 * kPageSize), SyscallSucceeds()) << addr; @@ -1206,44 +1207,43 @@ TEST_F(PtyTest, PartialBadBuffer) { TEST_F(PtyTest, SimpleEcho) { constexpr char kInput[] = "Mr. Eko"; - EXPECT_THAT(WriteFd(master_.get(), kInput, strlen(kInput)), + EXPECT_THAT(WriteFd(main_.get(), kInput, strlen(kInput)), SyscallSucceedsWithValue(strlen(kInput))); char buf[100] = {}; - ExpectReadable(master_, strlen(kInput), buf); + ExpectReadable(main_, strlen(kInput), buf); EXPECT_STREQ(buf, kInput); - ExpectFinished(master_); + ExpectFinished(main_); } TEST_F(PtyTest, GetWindowSize) { struct winsize ws; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &ws), SyscallSucceeds()); EXPECT_EQ(ws.ws_row, 0); EXPECT_EQ(ws.ws_col, 0); } -TEST_F(PtyTest, SetSlaveWindowSize) { +TEST_F(PtyTest, SetReplicaWindowSize) { constexpr uint16_t kRows = 343; constexpr uint16_t kCols = 2401; struct winsize ws = {.ws_row = kRows, .ws_col = kCols}; - ASSERT_THAT(ioctl(slave_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; - ASSERT_THAT(ioctl(master_.get(), TIOCGWINSZ, &retrieved_ws), - SyscallSucceeds()); + ASSERT_THAT(ioctl(main_.get(), TIOCGWINSZ, &retrieved_ws), SyscallSucceeds()); EXPECT_EQ(retrieved_ws.ws_row, kRows); EXPECT_EQ(retrieved_ws.ws_col, kCols); } -TEST_F(PtyTest, SetMasterWindowSize) { +TEST_F(PtyTest, SetMainWindowSize) { constexpr uint16_t kRows = 343; constexpr uint16_t kCols = 2401; struct winsize ws = {.ws_row = kRows, .ws_col = kCols}; - ASSERT_THAT(ioctl(master_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); + ASSERT_THAT(ioctl(main_.get(), TIOCSWINSZ, &ws), SyscallSucceeds()); struct winsize retrieved_ws = {}; - ASSERT_THAT(ioctl(slave_.get(), TIOCGWINSZ, &retrieved_ws), + ASSERT_THAT(ioctl(replica_.get(), TIOCGWINSZ, &retrieved_ws), SyscallSucceeds()); EXPECT_EQ(retrieved_ws.ws_row, kRows); EXPECT_EQ(retrieved_ws.ws_col, kCols); @@ -1252,8 +1252,8 @@ TEST_F(PtyTest, SetMasterWindowSize) { class JobControlTest : public ::testing::Test { protected: void SetUp() override { - master_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - slave_ = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master_)); + main_ = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); + replica_ = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main_)); // Make this a session leader, which also drops the controlling terminal. // In the gVisor test environment, this test will be run as the session @@ -1263,61 +1263,82 @@ class JobControlTest : public ::testing::Test { } } - // Master and slave ends of the PTY. Non-blocking. - FileDescriptor master_; - FileDescriptor slave_; + PosixError RunInChild(SubprocessCallback childFunc) { + pid_t child = fork(); + if (!child) { + childFunc(); + _exit(0); + } + int wstatus; + if (waitpid(child, &wstatus, 0) != child) { + return PosixError( + errno, absl::StrCat("child failed with wait status: ", wstatus)); + } + return PosixError(wstatus, "process returned"); + } + + // Main and replica ends of the PTY. Non-blocking. + FileDescriptor main_; + FileDescriptor replica_; }; -TEST_F(JobControlTest, SetTTYMaster) { - ASSERT_THAT(ioctl(master_.get(), TIOCSCTTY, 0), SyscallSucceeds()); +TEST_F(JobControlTest, SetTTYMain) { + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(main_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(!replica_.get(), TIOCSCTTY, 0)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYNonLeader) { // Fork a process that won't be the session leader. - pid_t child = fork(); - if (!child) { - // We shouldn't be able to set the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 0)); - _exit(0); - } - - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + auto res = + RunInChild([=]() { TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 0)); }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYBadArg) { - // Despite the man page saying arg should be 0 here, Linux doesn't actually - // check. - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 1), SyscallSucceeds()); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, SetTTYDifferentSession) { SKIP_IF(ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_SYS_ADMIN))); - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Fork, join a new session, and try to steal the parent's controlling - // terminal, which should fail. - pid_t child = fork(); - if (!child) { + auto res = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - // We shouldn't be able to steal the terminal. - TEST_PCHECK(ioctl(slave_.get(), TIOCSCTTY, 1)); - _exit(0); - } + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + // Fork, join a new session, and try to steal the parent's controlling + // terminal, which should fail. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // We shouldn't be able to steal the terminal. + TEST_PCHECK(ioctl(replica_.get(), TIOCSCTTY, 1)); + _exit(0); + } + + int gcwstatus; + TEST_PCHECK(waitpid(grandchild, &gcwstatus, 0) == grandchild); + TEST_PCHECK(gcwstatus == 0); + }); } TEST_F(JobControlTest, ReleaseTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Make sure we're ignoring SIGHUP, which will be sent to this process once we // disconnect they TTY. @@ -1327,48 +1348,60 @@ TEST_F(JobControlTest, ReleaseTTY) { sigemptyset(&sa.sa_mask); struct sigaction old_sa; EXPECT_THAT(sigaction(SIGHUP, &sa, &old_sa), SyscallSucceeds()); - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); } TEST_F(JobControlTest, ReleaseUnsetTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + ASSERT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, ReleaseWrongTTY) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - ASSERT_THAT(ioctl(master_.get(), TIOCNOTTY), SyscallFailsWithErrno(ENOTTY)); + auto res = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(main_.get(), TIOCNOTTY) < 0 && errno == ENOTTY); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, ReleaseTTYNonLeader) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, ReleaseTTYDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - pid_t child = fork(); - if (!child) { - // Join a new session, then try to disconnect. + auto ret = RunInChild([=]() { TEST_PCHECK(setsid() >= 0); - TEST_PCHECK(ioctl(slave_.get(), TIOCNOTTY)); - _exit(0); - } - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_EQ(wstatus, 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + pid_t grandchild = fork(); + if (!grandchild) { + // Join a new session, then try to disconnect. + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(ioctl(replica_.get(), TIOCNOTTY)); + _exit(0); + } + + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(wstatus == 0); + }); + ASSERT_NO_ERRNO(ret); } // Used by the child process spawned in ReleaseTTYSignals to track received @@ -1387,7 +1420,7 @@ void sig_handler(int signum) { received |= signum; } // - Checks that thread 1 got both signals // - Checks that thread 2 didn't get any signals. TEST_F(JobControlTest, ReleaseTTYSignals) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + ASSERT_THAT(ioctl(replica_.get(), TIOCSCTTY, 0), SyscallSucceeds()); received = 0; struct sigaction sa = {}; @@ -1439,7 +1472,7 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { // Release the controlling terminal, sending SIGHUP and SIGCONT to all other // processes in this process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCNOTTY), SyscallSucceeds()); + EXPECT_THAT(ioctl(replica_.get(), TIOCNOTTY), SyscallSucceeds()); EXPECT_THAT(sigaction(SIGHUP, &old_sa, NULL), SyscallSucceeds()); @@ -1456,20 +1489,21 @@ TEST_F(JobControlTest, ReleaseTTYSignals) { } TEST_F(JobControlTest, GetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - pid_t foreground_pgid; - pid_t pid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), - SyscallSucceeds()); - ASSERT_THAT(pid = getpid(), SyscallSucceeds()); - - ASSERT_EQ(foreground_pgid, pid); + auto res = RunInChild([=]() { + pid_t pid, foreground_pgid; + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 1)); + TEST_PCHECK(!ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid)); + TEST_PCHECK((pid = getpid()) >= 0); + TEST_PCHECK(pid == foreground_pgid); + }); + ASSERT_NO_ERRNO(res); } TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // At this point there's no controlling terminal, so TIOCGPGRP should fail. pid_t foreground_pgid; - ASSERT_THAT(ioctl(slave_.get(), TIOCGPGRP, &foreground_pgid), + ASSERT_THAT(ioctl(replica_.get(), TIOCGPGRP, &foreground_pgid), SyscallFailsWithErrno(ENOTTY)); } @@ -1479,113 +1513,125 @@ TEST_F(JobControlTest, GetForegroundProcessGroupNonControlling) { // - sets that child as the foreground process group // - kills its child and sets itself as the foreground process group. TEST_F(JobControlTest, SetForegroundProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. - struct sigaction sa = {}; - sa.sa_handler = SIG_IGN; - sa.sa_flags = 0; - sigemptyset(&sa.sa_mask); - sigaction(SIGTTOU, &sa, NULL); - - // Set ourself as the foreground process group. - ASSERT_THAT(tcsetpgrp(slave_.get(), getpgid(0)), SyscallSucceeds()); - - // Create a new process that just waits to be signaled. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!pause()); - // We should never reach this. - _exit(1); - } - - // Make the child its own process group, then make it the controlling process - // group of the terminal. - ASSERT_THAT(setpgid(child, child), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), child), SyscallSucceeds()); + auto res = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Ignore SIGTTOU so that we don't stop ourself when calling tcsetpgrp. + struct sigaction sa = {}; + sa.sa_handler = SIG_IGN; + sa.sa_flags = 0; + sigemptyset(&sa.sa_mask); + sigaction(SIGTTOU, &sa, NULL); + + // Set ourself as the foreground process group. + TEST_PCHECK(!tcsetpgrp(replica_.get(), getpgid(0))); + + // Create a new process that just waits to be signaled. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!pause()); + // We should never reach this. + _exit(1); + } - // Sanity check - we're still the controlling session. - ASSERT_EQ(getsid(0), getsid(child)); + // Make the child its own process group, then make it the controlling + // process group of the terminal. + TEST_PCHECK(!setpgid(grandchild, grandchild)); + TEST_PCHECK(!tcsetpgrp(replica_.get(), grandchild)); - // Signal the child, wait for it to exit, then retake the terminal. - ASSERT_THAT(kill(child, SIGTERM), SyscallSucceeds()); - int wstatus; - ASSERT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - ASSERT_TRUE(WIFSIGNALED(wstatus)); - ASSERT_EQ(WTERMSIG(wstatus), SIGTERM); + // Sanity check - we're still the controlling session. + TEST_PCHECK(getsid(0) == getsid(grandchild)); - // Set ourself as the foreground process. - pid_t pgid; - ASSERT_THAT(pgid = getpgid(0), SyscallSucceeds()); - ASSERT_THAT(tcsetpgrp(slave_.get(), pgid), SyscallSucceeds()); + // Signal the child, wait for it to exit, then retake the terminal. + TEST_PCHECK(!kill(grandchild, SIGTERM)); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFSIGNALED(wstatus)); + TEST_PCHECK(WTERMSIG(wstatus) == SIGTERM); + + // Set ourself as the foreground process. + pid_t pgid; + TEST_PCHECK(pgid = getpgid(0) == 0); + TEST_PCHECK(!tcsetpgrp(replica_.get(), pgid)); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupWrongTTY) { pid_t pid = getpid(); - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), + ASSERT_THAT(ioctl(replica_.get(), TIOCSPGRP, &pid), SyscallFailsWithErrno(ENOTTY)); } TEST_F(JobControlTest, SetForegroundProcessGroupNegPgid) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - pid_t pid = -1; - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &pid), - SyscallFailsWithErrno(EINVAL)); + pid_t pid = -1; + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &pid) && errno == EINVAL); + }); + ASSERT_NO_ERRNO(ret); } TEST_F(JobControlTest, SetForegroundProcessGroupEmptyProcessGroup) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); - - // Create a new process, put it in a new process group, make that group the - // foreground process group, then have the process wait. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(!setpgid(0, 0)); - _exit(0); - } + auto ret = RunInChild([=]() { + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); + + // Create a new process, put it in a new process group, make that group the + // foreground process group, then have the process wait. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(!setpgid(0, 0)); + _exit(0); + } - // Wait for the child to exit. - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - // The child's process group doesn't exist anymore - this should fail. - ASSERT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(ESRCH)); + // Wait for the child to exit. + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + // The child's process group doesn't exist anymore - this should fail. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) != 0 && + errno == ESRCH); + }); } TEST_F(JobControlTest, SetForegroundProcessGroupDifferentSession) { - ASSERT_THAT(ioctl(slave_.get(), TIOCSCTTY, 0), SyscallSucceeds()); + auto ret = RunInChild([=]() { + TEST_PCHECK(setsid() >= 0); + TEST_PCHECK(!ioctl(replica_.get(), TIOCSCTTY, 0)); - int sync_setsid[2]; - int sync_exit[2]; - ASSERT_THAT(pipe(sync_setsid), SyscallSucceeds()); - ASSERT_THAT(pipe(sync_exit), SyscallSucceeds()); + int sync_setsid[2]; + int sync_exit[2]; + TEST_PCHECK(pipe(sync_setsid) >= 0); + TEST_PCHECK(pipe(sync_exit) >= 0); - // Create a new process and put it in a new session. - pid_t child = fork(); - if (!child) { - TEST_PCHECK(setsid() >= 0); - // Tell the parent we're in a new session. - char c = 'c'; - TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); - TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); - _exit(0); - } + // Create a new process and put it in a new session. + pid_t grandchild = fork(); + if (!grandchild) { + TEST_PCHECK(setsid() >= 0); + // Tell the parent we're in a new session. + char c = 'c'; + TEST_PCHECK(WriteFd(sync_setsid[1], &c, 1) == 1); + TEST_PCHECK(ReadFd(sync_exit[0], &c, 1) == 1); + _exit(0); + } - // Wait for the child to tell us it's in a new session. - char c = 'c'; - ASSERT_THAT(ReadFd(sync_setsid[0], &c, 1), SyscallSucceedsWithValue(1)); + // Wait for the child to tell us it's in a new session. + char c = 'c'; + TEST_PCHECK(ReadFd(sync_setsid[0], &c, 1) == 1); - // Child is in a new session, so we can't make it the foregroup process group. - EXPECT_THAT(ioctl(slave_.get(), TIOCSPGRP, &child), - SyscallFailsWithErrno(EPERM)); + // Child is in a new session, so we can't make it the foregroup process + // group. + TEST_PCHECK(ioctl(replica_.get(), TIOCSPGRP, &grandchild) && + errno == EPERM); - EXPECT_THAT(WriteFd(sync_exit[1], &c, 1), SyscallSucceedsWithValue(1)); + TEST_PCHECK(WriteFd(sync_exit[1], &c, 1) == 1); - int wstatus; - EXPECT_THAT(waitpid(child, &wstatus, 0), SyscallSucceedsWithValue(child)); - EXPECT_TRUE(WIFEXITED(wstatus)); - EXPECT_EQ(WEXITSTATUS(wstatus), 0); + int wstatus; + TEST_PCHECK(waitpid(grandchild, &wstatus, 0) == grandchild); + TEST_PCHECK(WIFEXITED(wstatus)); + TEST_PCHECK(!WEXITSTATUS(wstatus)); + }); + ASSERT_NO_ERRNO(ret); } // Verify that we don't hang when creating a new session from an orphaned diff --git a/test/syscalls/linux/pty_root.cc b/test/syscalls/linux/pty_root.cc index 1d7dbefdb..a534cf0bb 100644 --- a/test/syscalls/linux/pty_root.cc +++ b/test/syscalls/linux/pty_root.cc @@ -48,12 +48,12 @@ TEST(JobControlRootTest, StealTTY) { ASSERT_THAT(setsid(), SyscallSucceeds()); } - FileDescriptor master = + FileDescriptor main = ASSERT_NO_ERRNO_AND_VALUE(Open("/dev/ptmx", O_RDWR | O_NONBLOCK)); - FileDescriptor slave = ASSERT_NO_ERRNO_AND_VALUE(OpenSlave(master)); + FileDescriptor replica = ASSERT_NO_ERRNO_AND_VALUE(OpenReplica(main)); - // Make slave the controlling terminal. - ASSERT_THAT(ioctl(slave.get(), TIOCSCTTY, 0), SyscallSucceeds()); + // Make replica the controlling terminal. + ASSERT_THAT(ioctl(replica.get(), TIOCSCTTY, 0), SyscallSucceeds()); // Fork, join a new session, and try to steal the parent's controlling // terminal, which should succeed when we have CAP_SYS_ADMIN and pass an arg @@ -62,9 +62,9 @@ TEST(JobControlRootTest, StealTTY) { if (!child) { ASSERT_THAT(setsid(), SyscallSucceeds()); // We shouldn't be able to steal the terminal with the wrong arg value. - TEST_PCHECK(ioctl(slave.get(), TIOCSCTTY, 0)); + TEST_PCHECK(ioctl(replica.get(), TIOCSCTTY, 0)); // We should be able to steal it if we are true root. - TEST_PCHECK(true_root == !ioctl(slave.get(), TIOCSCTTY, 1)); + TEST_PCHECK(true_root == !ioctl(replica.get(), TIOCSCTTY, 1)); _exit(0); } diff --git a/test/syscalls/linux/socket_inet_loopback.cc b/test/syscalls/linux/socket_inet_loopback.cc index 7c1d6a414..67893033c 100644 --- a/test/syscalls/linux/socket_inet_loopback.cc +++ b/test/syscalls/linux/socket_inet_loopback.cc @@ -97,13 +97,9 @@ TEST(BadSocketPairArgs, ValidateErrForBadCallsToSocketPair) { ASSERT_THAT(socketpair(AF_INET6, 0, 0, fd), SyscallFailsWithErrno(ESOCKTNOSUPPORT)); - // Invalid AF will return ENOAFSUPPORT or EPERM. - ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), - ::testing::AnyOf(SyscallFailsWithErrno(EAFNOSUPPORT), - SyscallFailsWithErrno(EPERM))); - ASSERT_THAT(socketpair(8675309, 0, 0, fd), - ::testing::AnyOf(SyscallFailsWithErrno(EAFNOSUPPORT), - SyscallFailsWithErrno(EPERM))); + // Invalid AF will fail. + ASSERT_THAT(socketpair(AF_MAX, 0, 0, fd), SyscallFails()); + ASSERT_THAT(socketpair(8675309, 0, 0, fd), SyscallFails()); } enum class Operation { diff --git a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc index 6275b5aed..539a4ec55 100644 --- a/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc +++ b/test/syscalls/linux/socket_ipv6_udp_unbound_netlink.cc @@ -26,7 +26,10 @@ namespace testing { // Checks that the loopback interface considers itself bound to all IPs in an // associated subnet. TEST_P(IPv6UDPUnboundSocketNetlinkTest, JoinSubnet) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + // TODO(b/166440211): Only run this test on gvisor or remove if the loopback + // interface should not consider itself bound to all IPs in an IPv6 subnet. + SKIP_IF(!IsRunningOnGvisor() || + !ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); // Add an IP address to the loopback interface. Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); diff --git a/test/syscalls/linux/symlink.cc b/test/syscalls/linux/symlink.cc index aa1f32c85..f4f6ff490 100644 --- a/test/syscalls/linux/symlink.cc +++ b/test/syscalls/linux/symlink.cc @@ -231,12 +231,13 @@ TEST(SymlinkTest, PwriteToSymlink) { const int data_size = 10; const std::string data = std::string(data_size, 'a'); - EXPECT_THAT(pwrite64(fd, data.c_str(), data.size(), 0), SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(pwrite64(fd, data.c_str(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); ASSERT_THAT(close(fd), SyscallSucceeds()); ASSERT_THAT(fd = open(name.c_str(), O_RDONLY), SyscallSucceeds()); - char buf[data_size+1]; + char buf[data_size + 1]; EXPECT_THAT(pread64(fd, buf, data.size(), 0), SyscallSucceeds()); buf[data.size()] = '\0'; EXPECT_STREQ(buf, data.c_str()); diff --git a/test/syscalls/linux/write.cc b/test/syscalls/linux/write.cc index d611f3c5d..77bcfbb8a 100644 --- a/test/syscalls/linux/write.cc +++ b/test/syscalls/linux/write.cc @@ -135,7 +135,8 @@ TEST_F(WriteTest, WriteExceedsRLimit) { TEST_F(WriteTest, WriteIncrementOffset) { TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); int fd = f.get(); EXPECT_THAT(WriteBytes(fd, 0), SyscallSucceedsWithValue(0)); @@ -143,7 +144,8 @@ TEST_F(WriteTest, WriteIncrementOffset) { const int bytes_total = 1024; - EXPECT_THAT(WriteBytes(fd, bytes_total), SyscallSucceedsWithValue(bytes_total)); + EXPECT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); } @@ -151,56 +153,68 @@ TEST_F(WriteTest, WriteIncrementOffsetSeek) { const std::string data = "hello world\n"; TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); - FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); int fd = f.get(); - const int seek_offset = data.size()/2; - ASSERT_THAT(lseek(fd, seek_offset, SEEK_SET), SyscallSucceedsWithValue(seek_offset)); + const int seek_offset = data.size() / 2; + ASSERT_THAT(lseek(fd, seek_offset, SEEK_SET), + SyscallSucceedsWithValue(seek_offset)); const int write_bytes = 512; - EXPECT_THAT(WriteBytes(fd, write_bytes), SyscallSucceedsWithValue(write_bytes)); - EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(seek_offset+write_bytes)); + EXPECT_THAT(WriteBytes(fd, write_bytes), + SyscallSucceedsWithValue(write_bytes)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(seek_offset + write_bytes)); } TEST_F(WriteTest, WriteIncrementOffsetAppend) { const std::string data = "hello world\n"; TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); - FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(),O_WRONLY | O_APPEND)); + FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE( + Open(tmpfile.path().c_str(), O_WRONLY | O_APPEND)); int fd = f.get(); EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); - EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(data.size()+1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_CUR), + SyscallSucceedsWithValue(data.size() + 1024)); } TEST_F(WriteTest, WriteIncrementOffsetEOF) { const std::string data = "hello world\n"; const TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( GetAbsoluteTestTmpdir(), data, TempPath::kDefaultFileMode)); - FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); int fd = f.get(); EXPECT_THAT(lseek(fd, 0, SEEK_END), SyscallSucceedsWithValue(data.size())); EXPECT_THAT(WriteBytes(fd, 1024), SyscallSucceedsWithValue(1024)); - EXPECT_THAT(lseek(fd, 0, SEEK_END), SyscallSucceedsWithValue(data.size()+1024)); + EXPECT_THAT(lseek(fd, 0, SEEK_END), + SyscallSucceedsWithValue(data.size() + 1024)); } TEST_F(WriteTest, PwriteNoChangeOffset) { TempPath tmpfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); - FileDescriptor f = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); + FileDescriptor f = + ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfile.path().c_str(), O_WRONLY)); int fd = f.get(); const std::string data = "hello world\n"; - EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0), SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(pwrite(fd, data.data(), data.size(), 0), + SyscallSucceedsWithValue(data.size())); EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(0)); const int bytes_total = 1024; - ASSERT_THAT(WriteBytes(fd, bytes_total), SyscallSucceedsWithValue(bytes_total)); + ASSERT_THAT(WriteBytes(fd, bytes_total), + SyscallSucceedsWithValue(bytes_total)); ASSERT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); - EXPECT_THAT(pwrite(fd, data.data(), data.size(), bytes_total), SyscallSucceedsWithValue(data.size())); + EXPECT_THAT(pwrite(fd, data.data(), data.size(), bytes_total), + SyscallSucceedsWithValue(data.size())); EXPECT_THAT(lseek(fd, 0, SEEK_CUR), SyscallSucceedsWithValue(bytes_total)); } diff --git a/test/util/pty_util.cc b/test/util/pty_util.cc index c01f916aa..5fa622922 100644 --- a/test/util/pty_util.cc +++ b/test/util/pty_util.cc @@ -23,25 +23,25 @@ namespace gvisor { namespace testing { -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master) { - PosixErrorOr<int> n = SlaveID(master); +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& main) { + PosixErrorOr<int> n = ReplicaID(main); if (!n.ok()) { return PosixErrorOr<FileDescriptor>(n.error()); } return Open(absl::StrCat("/dev/pts/", n.ValueOrDie()), O_RDWR | O_NONBLOCK); } -PosixErrorOr<int> SlaveID(const FileDescriptor& master) { +PosixErrorOr<int> ReplicaID(const FileDescriptor& main) { // Get pty index. int n; - int ret = ioctl(master.get(), TIOCGPTN, &n); + int ret = ioctl(main.get(), TIOCGPTN, &n); if (ret < 0) { return PosixError(errno, "ioctl(TIOCGPTN) failed"); } // Unlock pts. int unlock = 0; - ret = ioctl(master.get(), TIOCSPTLCK, &unlock); + ret = ioctl(main.get(), TIOCSPTLCK, &unlock); if (ret < 0) { return PosixError(errno, "ioctl(TIOSPTLCK) failed"); } diff --git a/test/util/pty_util.h b/test/util/pty_util.h index 0722da379..dff6adab5 100644 --- a/test/util/pty_util.h +++ b/test/util/pty_util.h @@ -21,11 +21,11 @@ namespace gvisor { namespace testing { -// Opens the slave end of the passed master as R/W and nonblocking. -PosixErrorOr<FileDescriptor> OpenSlave(const FileDescriptor& master); +// Opens the replica end of the passed main as R/W and nonblocking. +PosixErrorOr<FileDescriptor> OpenReplica(const FileDescriptor& main); -// Get the number of the slave end of the master. -PosixErrorOr<int> SlaveID(const FileDescriptor& master); +// Get the number of the replica end of the main. +PosixErrorOr<int> ReplicaID(const FileDescriptor& main); } // namespace testing } // namespace gvisor |