diff options
Diffstat (limited to 'pkg')
59 files changed, 658 insertions, 572 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 29ead20d0..38288bdb7 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -79,6 +79,7 @@ go_library( "//pkg/abi", "//pkg/bits", "//pkg/context", + "//pkg/hostarch", "//pkg/marshal", "//pkg/marshal/primitive", ], diff --git a/pkg/abi/linux/signal.go b/pkg/abi/linux/signal.go index 6ca57ffbb..bbf7f6580 100644 --- a/pkg/abi/linux/signal.go +++ b/pkg/abi/linux/signal.go @@ -16,6 +16,7 @@ package linux import ( "gvisor.dev/gvisor/pkg/bits" + "gvisor.dev/gvisor/pkg/hostarch" ) const ( @@ -165,7 +166,7 @@ const ( SIG_IGN = 1 ) -// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h +// Signal action flags for rt_sigaction(2), from uapi/asm-generic/signal.h. const ( SA_NOCLDSTOP = 0x00000001 SA_NOCLDWAIT = 0x00000002 @@ -179,6 +180,12 @@ const ( SA_ONESHOT = SA_RESETHAND ) +// Signal stack flags for signalstack(2), from include/uapi/linux/signal.h. +const ( + SS_ONSTACK = 1 + SS_DISABLE = 2 +) + // Signal info types. const ( SI_MASK = 0xffff0000 @@ -227,6 +234,48 @@ type Sigevent struct { UnRemainder [44]byte } +// LINT.IfChange + +// SigAction represents struct sigaction. +// +// +marshal +// +stateify savable +type SigAction struct { + Handler uint64 + Flags uint64 + Restorer uint64 + Mask SignalSet +} + +// LINT.ThenChange(../../safecopy/safecopy_unsafe.go) + +// SignalStack represents information about a user stack, and is equivalent to +// stack_t. +// +// +marshal +// +stateify savable +type SignalStack struct { + Addr uint64 + Flags uint32 + _ uint32 + Size uint64 +} + +// Contains checks if the stack pointer is within this stack. +func (s *SignalStack) Contains(sp hostarch.Addr) bool { + return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size) +} + +// Top returns the stack's top address. +func (s *SignalStack) Top() hostarch.Addr { + return hostarch.Addr(s.Addr + s.Size) +} + +// IsEnabled returns true iff this signal stack is marked as enabled. +func (s *SignalStack) IsEnabled() bool { + return s.Flags&SS_DISABLE == 0 +} + // Possible values for Sigevent.Notify, aka struct sigevent::sigev_notify. const ( SIGEV_SIGNAL = 0 diff --git a/pkg/metric/metric.go b/pkg/metric/metric.go index fdeee3a5f..4829ae7ce 100644 --- a/pkg/metric/metric.go +++ b/pkg/metric/metric.go @@ -36,17 +36,22 @@ var ( // new metric after initialization. ErrInitializationDone = errors.New("metric cannot be created after initialization is complete") - // createdSentryMetrics indicates that the sentry metrics are created. - createdSentryMetrics = false - // WeirdnessMetric is a metric with fields created to track the number // of weird occurrences such as time fallback, partial_result, vsyscall // count, watchdog startup timeouts and stuck tasks. - WeirdnessMetric *Uint64Metric + WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", + Field{ + name: "weirdness_type", + allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, + }) // SuspiciousOperationsMetric is a metric with fields created to detect // operations such as opening an executable file to write from a gofer. - SuspiciousOperationsMetric *Uint64Metric + SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", + Field{ + name: "operation_type", + allowedValues: []string{"opened_write_execute_file"}, + }) ) // Uint64Metric encapsulates a uint64 that represents some kind of metric to be @@ -84,17 +89,21 @@ var ( // Precondition: // * All metrics are registered. // * Initialize/Disable has not been called. -func Initialize() { +func Initialize() error { if initialized { - panic("Initialize/Disable called more than once") + return errors.New("metric.Initialize called after metric.Initialize or metric.Disable") } - initialized = true m := pb.MetricRegistration{} for _, v := range allMetrics.m { m.Metrics = append(m.Metrics, v.metadata) } - eventchannel.Emit(&m) + if err := eventchannel.Emit(&m); err != nil { + return fmt.Errorf("unable to emit metric initialize event: %w", err) + } + + initialized = true + return nil } // Disable sends an empty metric registration event over the event channel, @@ -103,16 +112,18 @@ func Initialize() { // Precondition: // * All metrics are registered. // * Initialize/Disable has not been called. -func Disable() { +func Disable() error { if initialized { - panic("Initialize/Disable called more than once") + return errors.New("metric.Disable called after metric.Initialize or metric.Disable") } - initialized = true m := pb.MetricRegistration{} if err := eventchannel.Emit(&m); err != nil { - panic("unable to emit metric disable event: " + err.Error()) + return fmt.Errorf("unable to emit metric disable event: %w", err) } + + initialized = true + return nil } type customUint64Metric struct { @@ -165,8 +176,8 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met } // Metrics can exist without fields. - if len(fields) > 1 { - panic("Sentry metrics support at most one field") + if l := len(fields); l > 1 { + return fmt.Errorf("%d fields provided, must be <= 1", l) } for _, field := range fields { @@ -182,7 +193,7 @@ func RegisterCustomUint64Metric(name string, cumulative, sync bool, units pb.Met // without fields and panics if it returns an error. func MustRegisterCustomUint64Metric(name string, cumulative, sync bool, description string, value func(...string) uint64, fields ...Field) { if err := RegisterCustomUint64Metric(name, cumulative, sync, pb.MetricMetadata_UNITS_NONE, description, value, fields...); err != nil { - panic(fmt.Sprintf("Unable to register metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to register metric %q: %s", name, err)) } } @@ -209,7 +220,7 @@ func NewUint64Metric(name string, sync bool, units pb.MetricMetadata_Units, desc func MustCreateNewUint64Metric(name string, sync bool, description string, fields ...Field) *Uint64Metric { m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NONE, description, fields...) if err != nil { - panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to create metric %q: %s", name, err)) } return m } @@ -219,7 +230,7 @@ func MustCreateNewUint64Metric(name string, sync bool, description string, field func MustCreateNewUint64NanosecondsMetric(name string, sync bool, description string) *Uint64Metric { m, err := NewUint64Metric(name, sync, pb.MetricMetadata_UNITS_NANOSECONDS, description) if err != nil { - panic(fmt.Sprintf("Unable to create metric %q: %v", name, err)) + panic(fmt.Sprintf("Unable to create metric %q: %s", name, err)) } return m } @@ -354,7 +365,7 @@ func EmitMetricUpdate() { m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, - Value: &pb.MetricValue_Uint64Value{t}, + Value: &pb.MetricValue_Uint64Value{Uint64Value: t}, }) case map[string]uint64: for fieldValue, metricValue := range t { @@ -369,7 +380,7 @@ func EmitMetricUpdate() { m.Metrics = append(m.Metrics, &pb.MetricValue{ Name: k, FieldValues: []string{fieldValue}, - Value: &pb.MetricValue_Uint64Value{metricValue}, + Value: &pb.MetricValue_Uint64Value{Uint64Value: metricValue}, }) } } @@ -390,26 +401,7 @@ func EmitMetricUpdate() { } } - eventchannel.Emit(&m) -} - -// CreateSentryMetrics creates the sentry metrics during kernel initialization. -func CreateSentryMetrics() { - if createdSentryMetrics { - return + if err := eventchannel.Emit(&m); err != nil { + log.Warningf("Unable to emit metrics: %s", err) } - - createdSentryMetrics = true - - WeirdnessMetric = MustCreateNewUint64Metric("/weirdness", true /* sync */, "Increment for weird occurrences of problems such as time fallback, partial result, vsyscalls invoked in the sandbox, watchdog startup timeouts and stuck tasks.", - Field{ - name: "weirdness_type", - allowedValues: []string{"time_fallback", "partial_result", "vsyscall_count", "watchdog_stuck_startup", "watchdog_stuck_tasks"}, - }) - - SuspiciousOperationsMetric = MustCreateNewUint64Metric("/suspicious_operations", true /* sync */, "Increment for suspicious operations such as opening an executable file to write from a gofer.", - Field{ - name: "operation_type", - allowedValues: []string{"opened_write_execute_file"}, - }) } diff --git a/pkg/metric/metric_test.go b/pkg/metric/metric_test.go index c71dfd460..1b4a9e73a 100644 --- a/pkg/metric/metric_test.go +++ b/pkg/metric/metric_test.go @@ -48,6 +48,8 @@ func (s *sliceEmitter) Reset() { var emitter sliceEmitter func init() { + reset() + eventchannel.AddEmitter(&emitter) } @@ -77,7 +79,9 @@ func TestInitialize(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } if len(emitter) != 1 { t.Fatalf("Initialize emitted %d events want 1", len(emitter)) @@ -149,7 +153,9 @@ func TestDisable(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Disable() + if err := Disable(); err != nil { + t.Fatalf("Disable(): %s", err) + } if len(emitter) != 1 { t.Fatalf("Initialize emitted %d events want 1", len(emitter)) @@ -178,7 +184,9 @@ func TestEmitMetricUpdate(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } // Don't care about the registration metrics. emitter.Reset() @@ -270,7 +278,9 @@ func TestEmitMetricUpdateWithFields(t *testing.T) { t.Fatalf("NewUint64Metric got err %v want nil", err) } - Initialize() + if err := Initialize(); err != nil { + t.Fatalf("Initialize(): %s", err) + } // Don't care about the registration metrics. emitter.Reset() diff --git a/pkg/safecopy/safecopy_unsafe.go b/pkg/safecopy/safecopy_unsafe.go index efbc2ddc1..3ec73f296 100644 --- a/pkg/safecopy/safecopy_unsafe.go +++ b/pkg/safecopy/safecopy_unsafe.go @@ -342,6 +342,9 @@ func errorFromFaultSignal(addr uintptr, sig int32) error { // handler however, and if this is function is being used externally then the // same courtesy is expected. func ReplaceSignalHandler(sig unix.Signal, handler uintptr, previous *uintptr) error { + // TODO(gvisor.dev/issue/6160): This struct is the same as linux.SigAction. + // Once the usermem dependency is removed from primitive, delete this replica + // and remove IFTTT comments in abi/linux/signal.go. var sa struct { handler uintptr flags uint64 diff --git a/pkg/seccomp/seccomp.go b/pkg/seccomp/seccomp.go index daea51c4d..8ffa1db37 100644 --- a/pkg/seccomp/seccomp.go +++ b/pkg/seccomp/seccomp.go @@ -36,14 +36,10 @@ const ( // Install generates BPF code based on the set of syscalls provided. It only // allows syscalls that conform to the specification. Syscalls that violate the -// specification will trigger RET_KILL_PROCESS, except for the cases below. -// -// RET_TRAP is used in violations, instead of RET_KILL_PROCESS, in the -// following cases: -// 1. Kernel doesn't support RET_KILL_PROCESS: RET_KILL_THREAD only kills the -// offending thread and often keeps the sentry hanging. -// 2. Debug: RET_TRAP generates a panic followed by a stack trace which is -// much easier to debug then RET_KILL_PROCESS which can't be caught. +// specification will trigger RET_KILL_PROCESS. If RET_KILL_PROCESS is not +// supported, violations will trigger RET_TRAP instead. RET_KILL_THREAD is not +// used because it only kills the offending thread and often keeps the sentry +// hanging. // // Be aware that RET_TRAP sends SIGSYS to the process and it may be ignored, // making it possible for the process to continue running after a violation. diff --git a/pkg/sentry/arch/BUILD b/pkg/sentry/arch/BUILD index c9c52530d..1f467b7c9 100644 --- a/pkg/sentry/arch/BUILD +++ b/pkg/sentry/arch/BUILD @@ -15,11 +15,9 @@ go_library( "arch_x86_impl.go", "auxv.go", "signal.go", - "signal_act.go", "signal_amd64.go", "signal_arm64.go", "signal_info.go", - "signal_stack.go", "stack.go", "stack_unsafe.go", "syscalls_amd64.go", diff --git a/pkg/sentry/arch/arch.go b/pkg/sentry/arch/arch.go index 290863ee6..d765d8374 100644 --- a/pkg/sentry/arch/arch.go +++ b/pkg/sentry/arch/arch.go @@ -134,21 +134,13 @@ type Context interface { // RegisterMap returns a map of all registers. RegisterMap() (map[string]uintptr, error) - // NewSignalAct returns a new object that is equivalent to struct sigaction - // in the guest architecture. - NewSignalAct() NativeSignalAct - - // NewSignalStack returns a new object that is equivalent to stack_t in the - // guest architecture. - NewSignalStack() NativeSignalStack - // SignalSetup modifies the context in preparation for handling the // given signal. // // st is the stack where the signal handler frame should be // constructed. // - // act is the SignalAct that specifies how this signal is being + // act is the SigAction that specifies how this signal is being // handled. // // info is the SignalInfo of the signal being delivered. @@ -157,7 +149,7 @@ type Context interface { // stack is not going to be used). // // sigset is the signal mask before entering the signal handler. - SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error + SignalSetup(st *Stack, act *linux.SigAction, info *SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error // SignalRestore restores context after returning from a signal // handler. @@ -167,7 +159,7 @@ type Context interface { // rt is true if SignalRestore is being entered from rt_sigreturn and // false if SignalRestore is being entered from sigreturn. // SignalRestore returns the thread's new signal mask. - SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) + SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) // CPUIDEmulate emulates a CPUID instruction according to current register state. CPUIDEmulate(l log.Logger) diff --git a/pkg/sentry/arch/signal.go b/pkg/sentry/arch/signal.go index 67d7edf68..b9fd14d10 100644 --- a/pkg/sentry/arch/signal.go +++ b/pkg/sentry/arch/signal.go @@ -19,50 +19,6 @@ import ( "gvisor.dev/gvisor/pkg/hostarch" ) -// SignalAct represents the action that should be taken when a signal is -// delivered, and is equivalent to struct sigaction. -// -// +marshal -// +stateify savable -type SignalAct struct { - Handler uint64 - Flags uint64 - Restorer uint64 // Only used on amd64. - Mask linux.SignalSet -} - -// SerializeFrom implements NativeSignalAct.SerializeFrom. -func (s *SignalAct) SerializeFrom(other *SignalAct) { - *s = *other -} - -// DeserializeTo implements NativeSignalAct.DeserializeTo. -func (s *SignalAct) DeserializeTo(other *SignalAct) { - *other = *s -} - -// SignalStack represents information about a user stack, and is equivalent to -// stack_t. -// -// +marshal -// +stateify savable -type SignalStack struct { - Addr uint64 - Flags uint32 - _ uint32 - Size uint64 -} - -// SerializeFrom implements NativeSignalStack.SerializeFrom. -func (s *SignalStack) SerializeFrom(other *SignalStack) { - *s = *other -} - -// DeserializeTo implements NativeSignalStack.DeserializeTo. -func (s *SignalStack) DeserializeTo(other *SignalStack) { - *other = *s -} - // SignalInfo represents information about a signal being delivered, and is // equivalent to struct siginfo in linux kernel(linux/include/uapi/asm-generic/siginfo.h). // diff --git a/pkg/sentry/arch/signal_act.go b/pkg/sentry/arch/signal_act.go deleted file mode 100644 index d3e2324a8..000000000 --- a/pkg/sentry/arch/signal_act.go +++ /dev/null @@ -1,83 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package arch - -import "gvisor.dev/gvisor/pkg/marshal" - -// Special values for SignalAct.Handler. -const ( - // SignalActDefault is SIG_DFL and specifies that the default behavior for - // a signal should be taken. - SignalActDefault = 0 - - // SignalActIgnore is SIG_IGN and specifies that a signal should be - // ignored. - SignalActIgnore = 1 -) - -// Available signal flags. -const ( - SignalFlagNoCldStop = 0x00000001 - SignalFlagNoCldWait = 0x00000002 - SignalFlagSigInfo = 0x00000004 - SignalFlagRestorer = 0x04000000 - SignalFlagOnStack = 0x08000000 - SignalFlagRestart = 0x10000000 - SignalFlagInterrupt = 0x20000000 - SignalFlagNoDefer = 0x40000000 - SignalFlagResetHandler = 0x80000000 -) - -// IsSigInfo returns true iff this handle expects siginfo. -func (s SignalAct) IsSigInfo() bool { - return s.Flags&SignalFlagSigInfo != 0 -} - -// IsNoDefer returns true iff this SignalAct has the NoDefer flag set. -func (s SignalAct) IsNoDefer() bool { - return s.Flags&SignalFlagNoDefer != 0 -} - -// IsRestart returns true iff this SignalAct has the Restart flag set. -func (s SignalAct) IsRestart() bool { - return s.Flags&SignalFlagRestart != 0 -} - -// IsResetHandler returns true iff this SignalAct has the ResetHandler flag set. -func (s SignalAct) IsResetHandler() bool { - return s.Flags&SignalFlagResetHandler != 0 -} - -// IsOnStack returns true iff this SignalAct has the OnStack flag set. -func (s SignalAct) IsOnStack() bool { - return s.Flags&SignalFlagOnStack != 0 -} - -// HasRestorer returns true iff this SignalAct has the Restorer flag set. -func (s SignalAct) HasRestorer() bool { - return s.Flags&SignalFlagRestorer != 0 -} - -// NativeSignalAct is a type that is equivalent to struct sigaction in the -// guest architecture. -type NativeSignalAct interface { - marshal.Marshallable - - // SerializeFrom copies the data in the host SignalAct s into this object. - SerializeFrom(s *SignalAct) - - // DeserializeTo copies the data in this object into the host SignalAct s. - DeserializeTo(s *SignalAct) -} diff --git a/pkg/sentry/arch/signal_amd64.go b/pkg/sentry/arch/signal_amd64.go index 082ed92b1..fa74a2551 100644 --- a/pkg/sentry/arch/signal_amd64.go +++ b/pkg/sentry/arch/signal_amd64.go @@ -76,21 +76,11 @@ const ( type UContext64 struct { Flags uint64 Link uint64 - Stack SignalStack + Stack linux.SignalStack MContext SignalContext64 Sigset linux.SignalSet } -// NewSignalAct implements Context.NewSignalAct. -func (c *context64) NewSignalAct() NativeSignalAct { - return &SignalAct{} -} - -// NewSignalStack implements Context.NewSignalStack. -func (c *context64) NewSignalStack() NativeSignalStack { - return &SignalStack{} -} - // From Linux 'arch/x86/include/uapi/asm/sigcontext.h' the following is the // size of the magic cookie at the end of the xsave frame. // @@ -110,7 +100,7 @@ func (c *context64) fpuFrameSize() (size int, useXsave bool) { // SignalSetup implements Context.SignalSetup. (Compare to Linux's // arch/x86/kernel/signal.c:__setup_rt_frame().) -func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error { +func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error { sp := st.Bottom // "The 128-byte area beyond the location pointed to by %rsp is considered @@ -187,7 +177,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // Prior to proceeding, figure out if the frame will exhaust the range // for the signal stack. This is not allowed, and should immediately // force signal delivery (reverting to the default handler). - if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) { + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) { return unix.EFAULT } @@ -203,7 +193,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt return err } ucAddr := st.Bottom - if act.HasRestorer() { + if act.Flags&linux.SA_RESTORER != 0 { // Push the restorer return address. // Note that this doesn't need to be popped. if _, err := primitive.CopyUint64Out(st, StackBottomMagic, act.Restorer); err != nil { @@ -237,15 +227,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // SignalRestore implements Context.SignalRestore. (Compare to Linux's // arch/x86/kernel/signal.c:sys_rt_sigreturn().) -func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { +func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) { // Copy out the stack frame. var uc UContext64 if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } var info SignalInfo if _, err := info.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } // Restore registers. diff --git a/pkg/sentry/arch/signal_arm64.go b/pkg/sentry/arch/signal_arm64.go index da71fb873..3d632e7fd 100644 --- a/pkg/sentry/arch/signal_arm64.go +++ b/pkg/sentry/arch/signal_arm64.go @@ -61,7 +61,7 @@ type FpsimdContext struct { type UContext64 struct { Flags uint64 Link uint64 - Stack SignalStack + Stack linux.SignalStack Sigset linux.SignalSet // glibc uses a 1024-bit sigset_t _pad [120]byte // (1024 - 64) / 8 = 120 @@ -71,18 +71,8 @@ type UContext64 struct { MContext SignalContext64 } -// NewSignalAct implements Context.NewSignalAct. -func (c *context64) NewSignalAct() NativeSignalAct { - return &SignalAct{} -} - -// NewSignalStack implements Context.NewSignalStack. -func (c *context64) NewSignalStack() NativeSignalStack { - return &SignalStack{} -} - // SignalSetup implements Context.SignalSetup. -func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt *SignalStack, sigset linux.SignalSet) error { +func (c *context64) SignalSetup(st *Stack, act *linux.SigAction, info *SignalInfo, alt *linux.SignalStack, sigset linux.SignalSet) error { sp := st.Bottom // Construct the UContext64 now since we need its size. @@ -114,7 +104,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt // Prior to proceeding, figure out if the frame will exhaust the range // for the signal stack. This is not allowed, and should immediately // force signal delivery (reverting to the default handler). - if act.IsOnStack() && alt.IsEnabled() && !alt.Contains(frameBottom) { + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() && !alt.Contains(frameBottom) { return unix.EFAULT } @@ -137,7 +127,7 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt c.Regs.Regs[0] = uint64(info.Signo) c.Regs.Regs[1] = uint64(infoAddr) c.Regs.Regs[2] = uint64(ucAddr) - c.Regs.Regs[30] = uint64(act.Restorer) + c.Regs.Regs[30] = act.Restorer // Save the thread's floating point state. c.sigFPState = append(c.sigFPState, c.fpState) @@ -147,15 +137,15 @@ func (c *context64) SignalSetup(st *Stack, act *SignalAct, info *SignalInfo, alt } // SignalRestore implements Context.SignalRestore. -func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, SignalStack, error) { +func (c *context64) SignalRestore(st *Stack, rt bool) (linux.SignalSet, linux.SignalStack, error) { // Copy out the stack frame. var uc UContext64 if _, err := uc.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } var info SignalInfo if _, err := info.CopyIn(st, StackBottomMagic); err != nil { - return 0, SignalStack{}, err + return 0, linux.SignalStack{}, err } // Restore registers. diff --git a/pkg/sentry/arch/signal_stack.go b/pkg/sentry/arch/signal_stack.go deleted file mode 100644 index c732c7503..000000000 --- a/pkg/sentry/arch/signal_stack.go +++ /dev/null @@ -1,68 +0,0 @@ -// Copyright 2018 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// +build 386 amd64 arm64 - -package arch - -import ( - "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/marshal" -) - -const ( - // SignalStackFlagOnStack is possible set on return from getaltstack, - // in order to indicate that the thread is currently on the alt stack. - SignalStackFlagOnStack = 1 - - // SignalStackFlagDisable is a flag to indicate the stack is disabled. - SignalStackFlagDisable = 2 -) - -// IsEnabled returns true iff this signal stack is marked as enabled. -func (s SignalStack) IsEnabled() bool { - return s.Flags&SignalStackFlagDisable == 0 -} - -// Top returns the stack's top address. -func (s SignalStack) Top() hostarch.Addr { - return hostarch.Addr(s.Addr + s.Size) -} - -// SetOnStack marks this signal stack as in use. -// -// Note that there is no corresponding ClearOnStack, and that this should only -// be called on copies that are serialized to userspace. -func (s *SignalStack) SetOnStack() { - s.Flags |= SignalStackFlagOnStack -} - -// Contains checks if the stack pointer is within this stack. -func (s *SignalStack) Contains(sp hostarch.Addr) bool { - return hostarch.Addr(s.Addr) < sp && sp <= hostarch.Addr(s.Addr+s.Size) -} - -// NativeSignalStack is a type that is equivalent to stack_t in the guest -// architecture. -type NativeSignalStack interface { - marshal.Marshallable - - // SerializeFrom copies the data in the host SignalStack s into this - // object. - SerializeFrom(s *SignalStack) - - // DeserializeTo copies the data in this object into the host SignalStack - // s. - DeserializeTo(s *SignalStack) -} diff --git a/pkg/sentry/fs/gofer/cache_policy.go b/pkg/sentry/fs/gofer/cache_policy.go index 07a564e92..f8b7a60fc 100644 --- a/pkg/sentry/fs/gofer/cache_policy.go +++ b/pkg/sentry/fs/gofer/cache_policy.go @@ -139,7 +139,7 @@ func (cp cachePolicy) revalidate(ctx context.Context, name string, parent, child // Walk from parent to child again. // - // TODO(b/112031682): If we have a directory FD in the parent + // NOTE(b/112031682): If we have a directory FD in the parent // inodeOperations, then we can use fstatat(2) to get the inode // attributes instead of making this RPC. qids, f, mask, attr, err := parentIops.fileState.file.walkGetAttr(ctx, []string{name}) diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 91ec4a142..eb09d54c3 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -1194,11 +1194,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // Requires 9P support. - return syserror.EINVAL - } - + // Resolve newParent first to verify that it's on this Mount. var ds *[]*dentry fs.renameMu.Lock() defer fs.renameMuUnlockAndCheckCaching(ctx, &ds) @@ -1206,8 +1202,21 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return syserror.EINVAL + } + if fs.opts.interop == InteropModeShared && opts.Flags&linux.RENAME_NOREPLACE != 0 { + // Requires 9P support to synchronize with other remote filesystem + // users. + return syserror.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -1280,6 +1289,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } var replacedVFSD *vfs.Dentry if replaced != nil { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index f50b0fb08..8fac53c60 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -635,12 +635,6 @@ func (fs *Filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - // Only RENAME_NOREPLACE is supported. - if opts.Flags&^linux.RENAME_NOREPLACE != 0 { - return syserror.EINVAL - } - noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 - fs.mu.Lock() defer fs.processDeferredDecRefs(ctx) defer fs.mu.Unlock() @@ -651,6 +645,13 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + // Only RENAME_NOREPLACE is supported. + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return syserror.EINVAL + } + noReplace := opts.Flags&linux.RENAME_NOREPLACE != 0 + mnt := rp.Mount() if mnt != oldParentVD.Mount() { return syserror.EXDEV diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 46c500427..6b6fa0bd5 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -1017,10 +1017,7 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - return syserror.EINVAL - } - + // Resolve newParent first to verify that it's on this Mount. var ds *[]*dentry fs.renameMu.Lock() defer fs.renameMuUnlockAndCheckDrop(ctx, &ds) @@ -1028,8 +1025,16 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + return syserror.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -1093,6 +1098,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa return err } if replaced != nil { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { diff --git a/pkg/sentry/fsimpl/testutil/BUILD b/pkg/sentry/fsimpl/testutil/BUILD index c766164c7..b3f9d1010 100644 --- a/pkg/sentry/fsimpl/testutil/BUILD +++ b/pkg/sentry/fsimpl/testutil/BUILD @@ -17,7 +17,6 @@ go_library( "//pkg/fspath", "//pkg/hostarch", "//pkg/memutil", - "//pkg/metric", "//pkg/sentry/fsbridge", "//pkg/sentry/fsimpl/tmpfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/testutil/kernel.go b/pkg/sentry/fsimpl/testutil/kernel.go index 438840ae2..97aa20cd1 100644 --- a/pkg/sentry/fsimpl/testutil/kernel.go +++ b/pkg/sentry/fsimpl/testutil/kernel.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/cpuid" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/memutil" - "gvisor.dev/gvisor/pkg/metric" "gvisor.dev/gvisor/pkg/sentry/fsbridge" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -63,8 +62,6 @@ func Boot() (*kernel.Kernel, error) { return nil, fmt.Errorf("creating platform: %v", err) } - metric.CreateSentryMetrics() - kernel.VFS2Enabled = true k := &kernel.Kernel{ Platform: plat, diff --git a/pkg/sentry/fsimpl/tmpfs/filesystem.go b/pkg/sentry/fsimpl/tmpfs/filesystem.go index 766289e60..ee7ff2961 100644 --- a/pkg/sentry/fsimpl/tmpfs/filesystem.go +++ b/pkg/sentry/fsimpl/tmpfs/filesystem.go @@ -496,20 +496,24 @@ func (fs *filesystem) ReadlinkAt(ctx context.Context, rp *vfs.ResolvingPath) (st // RenameAt implements vfs.FilesystemImpl.RenameAt. func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldParentVD vfs.VirtualDentry, oldName string, opts vfs.RenameOptions) error { - if opts.Flags != 0 { - // TODO(b/145974740): Support renameat2 flags. - return syserror.EINVAL - } - - // Resolve newParent first to verify that it's on this Mount. + // Resolve newParentDir first to verify that it's on this Mount. fs.mu.Lock() defer fs.mu.Unlock() newParentDir, err := walkParentDirLocked(ctx, rp, rp.Start().Impl().(*dentry)) if err != nil { return err } + + if opts.Flags&^linux.RENAME_NOREPLACE != 0 { + // TODO(b/145974740): Support other renameat2 flags. + return syserror.EINVAL + } + newName := rp.Component() if newName == "." || newName == ".." { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } return syserror.EBUSY } mnt := rp.Mount() @@ -556,6 +560,9 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa } replaced, ok := newParentDir.childMap[newName] if ok { + if opts.Flags&linux.RENAME_NOREPLACE != 0 { + return syserror.EEXIST + } replacedDir, ok := replaced.inode.impl.(*directory) if ok { if !renamed.inode.isDir() { diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index a1ec6daab..188c0ebff 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "seqatomic_taskgoroutineschedinfo_unsafe.go", package = "kernel", suffix = "TaskGoroutineSchedInfo", - template = "//pkg/sync:generic_seqatomic", + template = "//pkg/sync/seqatomic:generic_seqatomic", types = { "Value": "TaskGoroutineSchedInfo", }, diff --git a/pkg/sentry/kernel/auth/BUILD b/pkg/sentry/kernel/auth/BUILD index 869e49ebc..12180351d 100644 --- a/pkg/sentry/kernel/auth/BUILD +++ b/pkg/sentry/kernel/auth/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_credentials_unsafe.go", package = "auth", suffix = "Credentials", - template = "//pkg/sync:generic_atomicptr", + template = "//pkg/sync/atomicptr:generic_atomicptr", types = { "Value": "Credentials", }, diff --git a/pkg/sentry/kernel/auth/credentials.go b/pkg/sentry/kernel/auth/credentials.go index 6862f2ef5..3325fedcb 100644 --- a/pkg/sentry/kernel/auth/credentials.go +++ b/pkg/sentry/kernel/auth/credentials.go @@ -125,7 +125,7 @@ func NewUserCredentials(kuid KUID, kgid KGID, extraKGIDs []KGID, capabilities *T creds.EffectiveCaps = capabilities.EffectiveCaps creds.BoundingCaps = capabilities.BoundingCaps creds.InheritableCaps = capabilities.InheritableCaps - // TODO(nlacasse): Support ambient capabilities. + // TODO(gvisor.dev/issue/3166): Support ambient capabilities. } else { // If no capabilities are specified, grant capabilities consistent with // setresuid + setresgid from NewRootCredentials to the given uid and diff --git a/pkg/sentry/kernel/futex/BUILD b/pkg/sentry/kernel/futex/BUILD index a75686cf3..6c31e082c 100644 --- a/pkg/sentry/kernel/futex/BUILD +++ b/pkg/sentry/kernel/futex/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "atomicptr_bucket_unsafe.go", package = "futex", suffix = "Bucket", - template = "//pkg/sync:generic_atomicptr", + template = "//pkg/sync/atomicptr:generic_atomicptr", types = { "Value": "bucket", }, diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index febe7fe50..c666be2cb 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -1861,7 +1861,9 @@ func (k *Kernel) PopulateNewCgroupHierarchy(root Cgroup) { return } t.mu.Lock() - t.enterCgroupLocked(root) + // A task can be in the cgroup if it has been created after the + // cgroup hierarchy was registered. + t.enterCgroupIfNotYetLocked(root) t.mu.Unlock() }) k.tasks.mu.RUnlock() diff --git a/pkg/sentry/kernel/signal_handlers.go b/pkg/sentry/kernel/signal_handlers.go index 768fda220..147cc41bb 100644 --- a/pkg/sentry/kernel/signal_handlers.go +++ b/pkg/sentry/kernel/signal_handlers.go @@ -16,7 +16,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sync" ) @@ -30,14 +29,14 @@ type SignalHandlers struct { mu sync.Mutex `state:"nosave"` // actions is the action to be taken upon receiving each signal. - actions map[linux.Signal]arch.SignalAct + actions map[linux.Signal]linux.SigAction } // NewSignalHandlers returns a new SignalHandlers specifying all default // actions. func NewSignalHandlers() *SignalHandlers { return &SignalHandlers{ - actions: make(map[linux.Signal]arch.SignalAct), + actions: make(map[linux.Signal]linux.SigAction), } } @@ -59,9 +58,9 @@ func (sh *SignalHandlers) CopyForExec() *SignalHandlers { sh.mu.Lock() defer sh.mu.Unlock() for sig, act := range sh.actions { - if act.Handler == arch.SignalActIgnore { - sh2.actions[sig] = arch.SignalAct{ - Handler: arch.SignalActIgnore, + if act.Handler == linux.SIG_IGN { + sh2.actions[sig] = linux.SigAction{ + Handler: linux.SIG_IGN, } } } @@ -73,15 +72,15 @@ func (sh *SignalHandlers) IsIgnored(sig linux.Signal) bool { sh.mu.Lock() defer sh.mu.Unlock() sa, ok := sh.actions[sig] - return ok && sa.Handler == arch.SignalActIgnore + return ok && sa.Handler == linux.SIG_IGN } // dequeueActionLocked returns the SignalAct that should be used to handle sig. // // Preconditions: sh.mu must be locked. -func (sh *SignalHandlers) dequeueAction(sig linux.Signal) arch.SignalAct { +func (sh *SignalHandlers) dequeueAction(sig linux.Signal) linux.SigAction { act := sh.actions[sig] - if act.IsResetHandler() { + if act.Flags&linux.SA_RESETHAND != 0 { delete(sh.actions, sig) } return act diff --git a/pkg/sentry/kernel/task.go b/pkg/sentry/kernel/task.go index be1371855..9290dc52b 100644 --- a/pkg/sentry/kernel/task.go +++ b/pkg/sentry/kernel/task.go @@ -151,7 +151,7 @@ type Task struct { // which the SA_ONSTACK flag is set. // // signalStack is exclusive to the task goroutine. - signalStack arch.SignalStack + signalStack linux.SignalStack // signalQueue is a set of registered waiters for signal-related events. // diff --git a/pkg/sentry/kernel/task_cgroup.go b/pkg/sentry/kernel/task_cgroup.go index 25d2504fa..7c138e80f 100644 --- a/pkg/sentry/kernel/task_cgroup.go +++ b/pkg/sentry/kernel/task_cgroup.go @@ -85,6 +85,14 @@ func (t *Task) enterCgroupLocked(c Cgroup) { c.Enter(t) } +// +checklocks:t.mu +func (t *Task) enterCgroupIfNotYetLocked(c Cgroup) { + if _, ok := t.cgroups[c]; ok { + return + } + t.enterCgroupLocked(c) +} + // LeaveCgroups removes t out from all its cgroups. func (t *Task) LeaveCgroups() { t.mu.Lock() diff --git a/pkg/sentry/kernel/task_exec.go b/pkg/sentry/kernel/task_exec.go index d9897e802..cf8571262 100644 --- a/pkg/sentry/kernel/task_exec.go +++ b/pkg/sentry/kernel/task_exec.go @@ -66,7 +66,6 @@ package kernel import ( "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -181,7 +180,7 @@ func (r *runSyscallAfterExecStop) execute(t *Task) taskRunState { t.tg.signalHandlers = t.tg.signalHandlers.CopyForExec() t.endStopCond.L = &t.tg.signalHandlers.mu // "Any alternate signal stack is not preserved (sigaltstack(2))." - execve(2) - t.signalStack = arch.SignalStack{Flags: arch.SignalStackFlagDisable} + t.signalStack = linux.SignalStack{Flags: linux.SS_DISABLE} // "The termination signal is reset to SIGCHLD (see clone(2))." t.tg.terminationSignal = linux.SIGCHLD // execed indicates that the process can no longer join a process group diff --git a/pkg/sentry/kernel/task_exit.go b/pkg/sentry/kernel/task_exit.go index b1af1a7ef..5b17c0065 100644 --- a/pkg/sentry/kernel/task_exit.go +++ b/pkg/sentry/kernel/task_exit.go @@ -28,6 +28,7 @@ import ( "errors" "fmt" "strconv" + "strings" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/arch" @@ -50,6 +51,23 @@ type ExitStatus struct { Signo int } +func (es ExitStatus) String() string { + var b strings.Builder + if code := es.Code; code != 0 { + if b.Len() != 0 { + b.WriteByte(' ') + } + _, _ = fmt.Fprintf(&b, "Code=%d", code) + } + if signal := es.Signo; signal != 0 { + if b.Len() != 0 { + b.WriteByte(' ') + } + _, _ = fmt.Fprintf(&b, "Signal=%d", signal) + } + return b.String() +} + // Signaled returns true if the ExitStatus indicates that the exiting task or // thread group was killed by a signal. func (es ExitStatus) Signaled() bool { @@ -652,10 +670,10 @@ func (t *Task) exitNotifyLocked(fromPtraceDetach bool) { t.parent.tg.signalHandlers.mu.Lock() if t.tg.terminationSignal == linux.SIGCHLD || fromPtraceDetach { if act, ok := t.parent.tg.signalHandlers.actions[linux.SIGCHLD]; ok { - if act.Handler == arch.SignalActIgnore { + if act.Handler == linux.SIG_IGN { t.exitParentAcked = true signalParent = false - } else if act.Flags&arch.SignalFlagNoCldWait != 0 { + } else if act.Flags&linux.SA_NOCLDWAIT != 0 { t.exitParentAcked = true } } diff --git a/pkg/sentry/kernel/task_signals.go b/pkg/sentry/kernel/task_signals.go index c2b9fc08f..b0ed0e023 100644 --- a/pkg/sentry/kernel/task_signals.go +++ b/pkg/sentry/kernel/task_signals.go @@ -86,7 +86,7 @@ var defaultActions = map[linux.Signal]SignalAction{ } // computeAction figures out what to do given a signal number -// and an arch.SignalAct. SIGSTOP always results in a SignalActionStop, +// and an linux.SigAction. SIGSTOP always results in a SignalActionStop, // and SIGKILL always results in a SignalActionTerm. // Signal 0 is always ignored as many programs use it for various internal functions // and don't expect it to do anything. @@ -97,7 +97,7 @@ var defaultActions = map[linux.Signal]SignalAction{ // 0, the default action is taken; // 1, the signal is ignored; // anything else, the function returns SignalActionHandler. -func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction { +func computeAction(sig linux.Signal, act linux.SigAction) SignalAction { switch sig { case linux.SIGSTOP: return SignalActionStop @@ -108,9 +108,9 @@ func computeAction(sig linux.Signal, act arch.SignalAct) SignalAction { } switch act.Handler { - case arch.SignalActDefault: + case linux.SIG_DFL: return defaultActions[sig] - case arch.SignalActIgnore: + case linux.SIG_IGN: return SignalActionIgnore default: return SignalActionHandler @@ -155,7 +155,7 @@ func (t *Task) PendingSignals() linux.SignalSet { } // deliverSignal delivers the given signal and returns the following run state. -func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunState { +func (t *Task) deliverSignal(info *arch.SignalInfo, act linux.SigAction) taskRunState { sigact := computeAction(linux.Signal(info.Signo), act) if t.haveSyscallReturn { @@ -172,7 +172,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS fallthrough case sre == syserror.ERESTART_RESTARTBLOCK: fallthrough - case (sre == syserror.ERESTARTSYS && !act.IsRestart()): + case (sre == syserror.ERESTARTSYS && act.Flags&linux.SA_RESTART == 0): t.Debugf("Not restarting syscall %d after errno %d: interrupted by signal %d", t.Arch().SyscallNo(), sre, info.Signo) t.Arch().SetReturn(uintptr(-ExtractErrno(syserror.EINTR, -1))) default: @@ -236,7 +236,7 @@ func (t *Task) deliverSignal(info *arch.SignalInfo, act arch.SignalAct) taskRunS // deliverSignalToHandler changes the task's userspace state to enter the given // user-configured handler for the given signal. -func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) error { +func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act linux.SigAction) error { // Signal delivery to an application handler interrupts restartable // sequences. t.rseqInterrupt() @@ -248,8 +248,8 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // N.B. This is a *copy* of the alternate stack that the user's signal // handler expects to see in its ucontext (even if it's not in use). alt := t.signalStack - if act.IsOnStack() && alt.IsEnabled() { - alt.SetOnStack() + if act.Flags&linux.SA_ONSTACK != 0 && alt.IsEnabled() { + alt.Flags |= linux.SS_ONSTACK if !alt.Contains(sp) { sp = hostarch.Addr(alt.Top()) } @@ -289,7 +289,7 @@ func (t *Task) deliverSignalToHandler(info *arch.SignalInfo, act arch.SignalAct) // Add our signal mask. newMask := t.signalMask | act.Mask - if !act.IsNoDefer() { + if act.Flags&linux.SA_NODEFER == 0 { newMask |= linux.SignalSetOf(linux.Signal(info.Signo)) } t.SetSignalMask(newMask) @@ -572,9 +572,9 @@ func (t *Task) forceSignal(sig linux.Signal, unconditional bool) { func (t *Task) forceSignalLocked(sig linux.Signal, unconditional bool) { blocked := linux.SignalSetOf(sig)&t.signalMask != 0 act := t.tg.signalHandlers.actions[sig] - ignored := act.Handler == arch.SignalActIgnore + ignored := act.Handler == linux.SIG_IGN if blocked || ignored || unconditional { - act.Handler = arch.SignalActDefault + act.Handler = linux.SIG_DFL t.tg.signalHandlers.actions[sig] = act if blocked { t.setSignalMaskLocked(t.signalMask &^ linux.SignalSetOf(sig)) @@ -641,17 +641,17 @@ func (t *Task) SetSavedSignalMask(mask linux.SignalSet) { } // SignalStack returns the task-private signal stack. -func (t *Task) SignalStack() arch.SignalStack { +func (t *Task) SignalStack() linux.SignalStack { t.p.PullFullState(t.MemoryManager().AddressSpace(), t.Arch()) alt := t.signalStack if t.onSignalStack(alt) { - alt.Flags |= arch.SignalStackFlagOnStack + alt.Flags |= linux.SS_ONSTACK } return alt } // onSignalStack returns true if the task is executing on the given signal stack. -func (t *Task) onSignalStack(alt arch.SignalStack) bool { +func (t *Task) onSignalStack(alt linux.SignalStack) bool { sp := hostarch.Addr(t.Arch().Stack()) return alt.Contains(sp) } @@ -661,30 +661,30 @@ func (t *Task) onSignalStack(alt arch.SignalStack) bool { // This value may not be changed if the task is currently executing on the // signal stack, i.e. if t.onSignalStack returns true. In this case, this // function will return false. Otherwise, true is returned. -func (t *Task) SetSignalStack(alt arch.SignalStack) bool { +func (t *Task) SetSignalStack(alt linux.SignalStack) bool { // Check that we're not executing on the stack. if t.onSignalStack(t.signalStack) { return false } - if alt.Flags&arch.SignalStackFlagDisable != 0 { + if alt.Flags&linux.SS_DISABLE != 0 { // Don't record anything beyond the flags. - t.signalStack = arch.SignalStack{ - Flags: arch.SignalStackFlagDisable, + t.signalStack = linux.SignalStack{ + Flags: linux.SS_DISABLE, } } else { // Mask out irrelevant parts: only disable matters. - alt.Flags &= arch.SignalStackFlagDisable + alt.Flags &= linux.SS_DISABLE t.signalStack = alt } return true } -// SetSignalAct atomically sets the thread group's signal action for signal sig +// SetSigAction atomically sets the thread group's signal action for signal sig // to *actptr (if actptr is not nil) and returns the old signal action. -func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (arch.SignalAct, error) { +func (tg *ThreadGroup) SetSigAction(sig linux.Signal, actptr *linux.SigAction) (linux.SigAction, error) { if !sig.IsValid() { - return arch.SignalAct{}, syserror.EINVAL + return linux.SigAction{}, syserror.EINVAL } tg.pidns.owner.mu.RLock() @@ -718,48 +718,6 @@ func (tg *ThreadGroup) SetSignalAct(sig linux.Signal, actptr *arch.SignalAct) (a return oldact, nil } -// CopyOutSignalAct converts the given SignalAct into an architecture-specific -// type and then copies it out to task memory. -func (t *Task) CopyOutSignalAct(addr hostarch.Addr, s *arch.SignalAct) error { - n := t.Arch().NewSignalAct() - n.SerializeFrom(s) - _, err := n.CopyOut(t, addr) - return err -} - -// CopyInSignalAct copies an architecture-specific sigaction type from task -// memory and then converts it into a SignalAct. -func (t *Task) CopyInSignalAct(addr hostarch.Addr) (arch.SignalAct, error) { - n := t.Arch().NewSignalAct() - var s arch.SignalAct - if _, err := n.CopyIn(t, addr); err != nil { - return s, err - } - n.DeserializeTo(&s) - return s, nil -} - -// CopyOutSignalStack converts the given SignalStack into an -// architecture-specific type and then copies it out to task memory. -func (t *Task) CopyOutSignalStack(addr hostarch.Addr, s *arch.SignalStack) error { - n := t.Arch().NewSignalStack() - n.SerializeFrom(s) - _, err := n.CopyOut(t, addr) - return err -} - -// CopyInSignalStack copies an architecture-specific stack_t from task memory -// and then converts it into a SignalStack. -func (t *Task) CopyInSignalStack(addr hostarch.Addr) (arch.SignalStack, error) { - n := t.Arch().NewSignalStack() - var s arch.SignalStack - if _, err := n.CopyIn(t, addr); err != nil { - return s, err - } - n.DeserializeTo(&s) - return s, nil -} - // groupStop is a TaskStop placed on tasks that have received a stop signal // (SIGSTOP, SIGTSTP, SIGTTIN, SIGTTOU). (The term "group-stop" originates from // the ptrace man page.) @@ -909,7 +867,7 @@ func (t *Task) signalStop(target *Task, code int32, status int32) { t.tg.signalHandlers.mu.Lock() defer t.tg.signalHandlers.mu.Unlock() act, ok := t.tg.signalHandlers.actions[linux.SIGCHLD] - if !ok || (act.Handler != arch.SignalActIgnore && act.Flags&arch.SignalFlagNoCldStop == 0) { + if !ok || (act.Handler != linux.SIG_IGN && act.Flags&linux.SA_NOCLDSTOP == 0) { sigchld := &arch.SignalInfo{ Signo: int32(linux.SIGCHLD), Code: code, diff --git a/pkg/sentry/kernel/task_start.go b/pkg/sentry/kernel/task_start.go index 32031cd70..41fd2d471 100644 --- a/pkg/sentry/kernel/task_start.go +++ b/pkg/sentry/kernel/task_start.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/hostarch" - "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/futex" @@ -131,7 +130,7 @@ func (ts *TaskSet) newTask(cfg *TaskConfig) (*Task, error) { runState: (*runApp)(nil), interruptChan: make(chan struct{}, 1), signalMask: cfg.SignalMask, - signalStack: arch.SignalStack{Flags: arch.SignalStackFlagDisable}, + signalStack: linux.SignalStack{Flags: linux.SS_DISABLE}, image: *image, fsContext: cfg.FSContext, fdTable: cfg.FDTable, diff --git a/pkg/sentry/kernel/thread_group.go b/pkg/sentry/kernel/thread_group.go index b92e98fa1..e22ddcd21 100644 --- a/pkg/sentry/kernel/thread_group.go +++ b/pkg/sentry/kernel/thread_group.go @@ -490,10 +490,10 @@ func (tg *ThreadGroup) SetForegroundProcessGroup(tty *TTY, pgid ProcessGroupID) tg.signalHandlers.mu.Lock() defer tg.signalHandlers.mu.Unlock() - // TODO(b/129283598): "If tcsetpgrp() is called by a member of a - // background process group in its session, and the calling process is - // not blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all - // members of this background process group." + // TODO(gvisor.dev/issue/6148): "If tcsetpgrp() is called by a member of a + // background process group in its session, and the calling process is not + // blocking or ignoring SIGTTOU, a SIGTTOU signal is sent to all members of + // this background process group." // tty must be the controlling terminal. if tg.tty != tty { diff --git a/pkg/sentry/pgalloc/pgalloc.go b/pkg/sentry/pgalloc/pgalloc.go index b81292c46..d1a883da4 100644 --- a/pkg/sentry/pgalloc/pgalloc.go +++ b/pkg/sentry/pgalloc/pgalloc.go @@ -1062,10 +1062,20 @@ func (f *MemoryFile) runReclaim() { break } - // If ManualZeroing is in effect, pages will be zeroed on allocation - // and may not be freed by decommitFile, so calling decommitFile is - // unnecessary. - if !f.opts.ManualZeroing { + if f.opts.ManualZeroing { + // If ManualZeroing is in effect, only hugepage-aligned regions may + // be safely passed to decommitFile. Pages will be zeroed on + // reallocation, so we don't need to perform any manual zeroing + // here, whether or not decommitFile succeeds. + if startAddr, ok := hostarch.Addr(fr.Start).HugeRoundUp(); ok { + if endAddr := hostarch.Addr(fr.End).HugeRoundDown(); startAddr < endAddr { + decommitFR := memmap.FileRange{uint64(startAddr), uint64(endAddr)} + if err := f.decommitFile(decommitFR); err != nil { + log.Warningf("Reclaim failed to decommit %v: %v", decommitFR, err) + } + } + } + } else { if err := f.decommitFile(fr); err != nil { log.Warningf("Reclaim failed to decommit %v: %v", fr, err) // Zero the pages manually. This won't reduce memory usage, but at diff --git a/pkg/sentry/sighandling/sighandling_unsafe.go b/pkg/sentry/sighandling/sighandling_unsafe.go index d6a2fbe34..3fe5c6770 100644 --- a/pkg/sentry/sighandling/sighandling_unsafe.go +++ b/pkg/sentry/sighandling/sighandling_unsafe.go @@ -21,25 +21,16 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" ) -// FIXME(gvisor.dev/issue/214): Move to pkg/abi/linux along with definitions in -// pkg/sentry/arch. -type sigaction struct { - handler uintptr - flags uint64 - restorer uintptr - mask uint64 -} - // IgnoreChildStop sets the SA_NOCLDSTOP flag, causing child processes to not // generate SIGCHLD when they stop. func IgnoreChildStop() error { - var sa sigaction + var sa linux.SigAction // Get the existing signal handler information, and set the flag. if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), 0, uintptr(unsafe.Pointer(&sa)), linux.SignalSetSize, 0, 0); e != 0 { return e } - sa.flags |= linux.SA_NOCLDSTOP + sa.Flags |= linux.SA_NOCLDSTOP if _, _, e := unix.RawSyscall6(unix.SYS_RT_SIGACTION, uintptr(unix.SIGCHLD), uintptr(unsafe.Pointer(&sa)), 0, linux.SignalSetSize, 0, 0); e != 0 { return e } diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 037ccfec8..d4b1bad67 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1798,11 +1798,6 @@ func setSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, nam } v := hostarch.ByteOrder.Uint32(optVal) - - if v == 0 { - socket.SetSockOptEmitUnimplementedEvent(t, name) - } - ep.SocketOptions().SetOutOfBandInline(v != 0) return nil diff --git a/pkg/sentry/strace/signal.go b/pkg/sentry/strace/signal.go index e5b379a20..5afc9525b 100644 --- a/pkg/sentry/strace/signal.go +++ b/pkg/sentry/strace/signal.go @@ -130,8 +130,8 @@ func sigAction(t *kernel.Task, addr hostarch.Addr) string { return "null" } - sa, err := t.CopyInSignalAct(addr) - if err != nil { + var sa linux.SigAction + if _, err := sa.CopyIn(t, addr); err != nil { return fmt.Sprintf("%#x (error copying sigaction: %v)", addr, err) } diff --git a/pkg/sentry/syscalls/linux/sys_signal.go b/pkg/sentry/syscalls/linux/sys_signal.go index 53b12dc41..39a333215 100644 --- a/pkg/sentry/syscalls/linux/sys_signal.go +++ b/pkg/sentry/syscalls/linux/sys_signal.go @@ -251,20 +251,20 @@ func RtSigaction(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S return 0, nil, syserror.EINVAL } - var newactptr *arch.SignalAct + var newactptr *linux.SigAction if newactarg != 0 { - newact, err := t.CopyInSignalAct(newactarg) - if err != nil { + var newact linux.SigAction + if _, err := newact.CopyIn(t, newactarg); err != nil { return 0, nil, err } newactptr = &newact } - oldact, err := t.ThreadGroup().SetSignalAct(sig, newactptr) + oldact, err := t.ThreadGroup().SetSigAction(sig, newactptr) if err != nil { return 0, nil, err } if oldactarg != 0 { - if err := t.CopyOutSignalAct(oldactarg, &oldact); err != nil { + if _, err := oldact.CopyOut(t, oldactarg); err != nil { return 0, nil, err } } @@ -325,13 +325,12 @@ func Sigaltstack(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.S alt := t.SignalStack() if oldaddr != 0 { - if err := t.CopyOutSignalStack(oldaddr, &alt); err != nil { + if _, err := alt.CopyOut(t, oldaddr); err != nil { return 0, nil, err } } if setaddr != 0 { - alt, err := t.CopyInSignalStack(setaddr) - if err != nil { + if _, err := alt.CopyIn(t, setaddr); err != nil { return 0, nil, err } // The signal stack cannot be changed if the task is currently diff --git a/pkg/sentry/time/BUILD b/pkg/sentry/time/BUILD index 1f617ca8f..362dea76d 100644 --- a/pkg/sentry/time/BUILD +++ b/pkg/sentry/time/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "seqatomic_parameters_unsafe.go", package = "time", suffix = "Parameters", - template = "//pkg/sync:generic_seqatomic", + template = "//pkg/sync/seqatomic:generic_seqatomic", types = { "Value": "Parameters", }, diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 82fd382c2..f93da3af1 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -220,7 +220,7 @@ func (vfs *VirtualFilesystem) ConnectMountAt(ctx context.Context, creds *auth.Cr vdDentry := vd.dentry vdDentry.mu.Lock() for { - if vdDentry.dead { + if vd.mount.umounted || vdDentry.dead { vdDentry.mu.Unlock() vfs.mountMu.Unlock() vd.DecRef(ctx) diff --git a/pkg/sync/BUILD b/pkg/sync/BUILD index 8b3a11c64..73791b456 100644 --- a/pkg/sync/BUILD +++ b/pkg/sync/BUILD @@ -1,5 +1,4 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template") package( default_visibility = ["//:sandbox"], @@ -8,45 +7,6 @@ package( exports_files(["LICENSE"]) -go_template( - name = "generic_atomicptr", - srcs = ["generic_atomicptr_unsafe.go"], - types = [ - "Value", - ], -) - -go_template( - name = "generic_atomicptrmap", - srcs = ["generic_atomicptrmap_unsafe.go"], - opt_consts = [ - "ShardOrder", - ], - opt_types = [ - "Hasher", - ], - types = [ - "Key", - "Value", - ], - deps = [ - ":sync", - "//pkg/gohacks", - ], -) - -go_template( - name = "generic_seqatomic", - srcs = ["generic_seqatomic_unsafe.go"], - types = [ - "Value", - ], - deps = [ - ":sync", - "//pkg/gohacks", - ], -) - go_library( name = "sync", srcs = [ diff --git a/pkg/sync/atomicptrtest/BUILD b/pkg/sync/atomicptr/BUILD index e97553254..a6a7f01ac 100644 --- a/pkg/sync/atomicptrtest/BUILD +++ b/pkg/sync/atomicptr/BUILD @@ -1,14 +1,23 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) +go_template( + name = "generic_atomicptr", + srcs = ["generic_atomicptr_unsafe.go"], + types = [ + "Value", + ], + visibility = ["//:sandbox"], +) + go_template_instance( name = "atomicptr_int", out = "atomicptr_int_unsafe.go", package = "atomicptr", suffix = "Int", - template = "//pkg/sync:generic_atomicptr", + template = ":generic_atomicptr", types = { "Value": "int", }, diff --git a/pkg/sync/atomicptrtest/atomicptr_test.go b/pkg/sync/atomicptr/atomicptr_test.go index 8fdc5112e..8fdc5112e 100644 --- a/pkg/sync/atomicptrtest/atomicptr_test.go +++ b/pkg/sync/atomicptr/atomicptr_test.go diff --git a/pkg/sync/generic_atomicptr_unsafe.go b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go index 82b6df18c..82b6df18c 100644 --- a/pkg/sync/generic_atomicptr_unsafe.go +++ b/pkg/sync/atomicptr/generic_atomicptr_unsafe.go diff --git a/pkg/sync/atomicptrmaptest/BUILD b/pkg/sync/atomicptrmap/BUILD index 3f71ae97d..b0e218c79 100644 --- a/pkg/sync/atomicptrmaptest/BUILD +++ b/pkg/sync/atomicptrmap/BUILD @@ -1,17 +1,36 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package( default_visibility = ["//visibility:private"], licenses = ["notice"], ) +go_template( + name = "generic_atomicptrmap", + srcs = ["generic_atomicptrmap_unsafe.go"], + opt_consts = [ + "ShardOrder", + ], + opt_types = [ + "Hasher", + ], + types = [ + "Key", + "Value", + ], + deps = [ + "//pkg/gohacks", + "//pkg/sync", + ], +) + go_template_instance( name = "test_atomicptrmap", out = "test_atomicptrmap_unsafe.go", package = "atomicptrmap", prefix = "test", - template = "//pkg/sync:generic_atomicptrmap", + template = ":generic_atomicptrmap", types = { "Key": "int64", "Value": "testValue", @@ -27,7 +46,7 @@ go_template_instance( package = "atomicptrmap", prefix = "test", suffix = "Sharded", - template = "//pkg/sync:generic_atomicptrmap", + template = ":generic_atomicptrmap", types = { "Key": "int64", "Value": "testValue", diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap.go b/pkg/sync/atomicptrmap/atomicptrmap.go index 867821ce9..867821ce9 100644 --- a/pkg/sync/atomicptrmaptest/atomicptrmap.go +++ b/pkg/sync/atomicptrmap/atomicptrmap.go diff --git a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go b/pkg/sync/atomicptrmap/atomicptrmap_test.go index 75a9997ef..75a9997ef 100644 --- a/pkg/sync/atomicptrmaptest/atomicptrmap_test.go +++ b/pkg/sync/atomicptrmap/atomicptrmap_test.go diff --git a/pkg/sync/generic_atomicptrmap_unsafe.go b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go index 3e98cb309..3e98cb309 100644 --- a/pkg/sync/generic_atomicptrmap_unsafe.go +++ b/pkg/sync/atomicptrmap/generic_atomicptrmap_unsafe.go diff --git a/pkg/sync/seqatomictest/BUILD b/pkg/sync/seqatomic/BUILD index 5f9164117..60f79ab54 100644 --- a/pkg/sync/seqatomictest/BUILD +++ b/pkg/sync/seqatomic/BUILD @@ -1,14 +1,27 @@ load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_generics:defs.bzl", "go_template_instance") +load("//tools/go_generics:defs.bzl", "go_template", "go_template_instance") package(licenses = ["notice"]) +go_template( + name = "generic_seqatomic", + srcs = ["generic_seqatomic_unsafe.go"], + types = [ + "Value", + ], + visibility = ["//:sandbox"], + deps = [ + ":sync", + "//pkg/gohacks", + ], +) + go_template_instance( name = "seqatomic_int", out = "seqatomic_int_unsafe.go", package = "seqatomic", suffix = "Int", - template = "//pkg/sync:generic_seqatomic", + template = ":generic_seqatomic", types = { "Value": "int", }, diff --git a/pkg/sync/generic_seqatomic_unsafe.go b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go index 9578c9c52..9578c9c52 100644 --- a/pkg/sync/generic_seqatomic_unsafe.go +++ b/pkg/sync/seqatomic/generic_seqatomic_unsafe.go diff --git a/pkg/sync/seqatomictest/seqatomic_test.go b/pkg/sync/seqatomic/seqatomic_test.go index 2c4568b07..2c4568b07 100644 --- a/pkg/sync/seqatomictest/seqatomic_test.go +++ b/pkg/sync/seqatomic/seqatomic_test.go diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index bddb1d0a2..735c28da1 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -41,7 +41,6 @@ package fdbased import ( "fmt" - "math" "sync/atomic" "golang.org/x/sys/unix" @@ -196,8 +195,12 @@ type Options struct { // option for an FD with a fanoutID already in use by another FD for a different // NIC will return an EINVAL. // +// Since fanoutID must be unique within the network namespace, we start with +// the PID to avoid collisions. The only way to be sure of avoiding collisions +// is to run in a new network namespace. +// // Must be accessed using atomic operations. -var fanoutID int32 = 0 +var fanoutID int32 = int32(unix.Getpid()) // New creates a new fd-based endpoint. // @@ -292,11 +295,6 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin } switch sa.(type) { case *unix.SockaddrLinklayer: - // See: PACKET_FANOUT_MAX in net/packet/internal.h - const packetFanoutMax = 1 << 16 - if fID > packetFanoutMax { - return nil, fmt.Errorf("host fanoutID limit exceeded, fanoutID must be <= %d", math.MaxUint16) - } // Enable PACKET_FANOUT mode if the underlying socket is of type // AF_PACKET. We do not enable PACKET_FANOUT_FLAG_DEFRAG as that will // prevent gvisor from receiving fragmented packets and the host does the @@ -317,7 +315,7 @@ func createInboundDispatcher(e *endpoint, fd int, isSocket bool, fID int32) (lin // // See: https://github.com/torvalds/linux/blob/7acac4b3196caee5e21fb5ea53f8bc124e6a16fc/net/packet/af_packet.c#L3881 const fanoutType = unix.PACKET_FANOUT_HASH - fanoutArg := int(fID) | fanoutType<<16 + fanoutArg := (int(fID) & 0xffff) | fanoutType<<16 if err := unix.SetsockoptInt(fd, unix.SOL_PACKET, unix.PACKET_FANOUT, fanoutArg); err != nil { return nil, fmt.Errorf("failed to enable PACKET_FANOUT option: %v", err) } diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index 6bee55634..c99297a51 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -720,7 +720,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -836,7 +837,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -855,10 +856,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer, inNICName string) { pkt.NICID = e.nic.ID() stats := e.stats stats.ip.ValidPacketsReceived.Increment() @@ -920,8 +921,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv4, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and will not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.ip.IPTablesInputDropped.Increment() return diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 6103574f7..12763add6 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -991,7 +991,8 @@ func (e *endpoint) forwardPacket(pkt *stack.PacketBuffer) ip.ForwardingError { return nil } - ep.handleValidatedPacket(h, pkt) + // The packet originally arrived on e so provide its NIC as the input NIC. + ep.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) return nil } @@ -1104,7 +1105,7 @@ func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { } } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } // handleLocalPacket is like HandlePacket except it does not perform the @@ -1123,10 +1124,10 @@ func (e *endpoint) handleLocalPacket(pkt *stack.PacketBuffer, canSkipRXChecksum return } - e.handleValidatedPacket(h, pkt) + e.handleValidatedPacket(h, pkt, e.nic.Name() /* inNICName */) } -func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) { +func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer, inNICName string) { pkt.NICID = e.nic.ID() stats := e.stats.ip stats.ValidPacketsReceived.Increment() @@ -1175,8 +1176,7 @@ func (e *endpoint) handleValidatedPacket(h header.IPv6, pkt *stack.PacketBuffer) // iptables filtering. All packets that reach here are intended for // this machine and need not be forwarded. - inNicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNicName, "" /* outNicName */); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Input, pkt, nil, "" /* preroutingAddr */, inNICName, "" /* outNicName */); !ok { // iptables is telling us to drop the packet. stats.IPTablesInputDropped.Increment() return @@ -1627,8 +1627,8 @@ func (e *endpoint) NetworkProtocolNumber() tcpip.NetworkProtocolNumber { func (e *endpoint) AddAndAcquirePermanentAddress(addr tcpip.AddressWithPrefix, peb stack.PrimaryEndpointBehavior, configType stack.AddressConfigType, deprecated bool) (stack.AddressEndpoint, tcpip.Error) { // TODO(b/169350103): add checks here after making sure we no longer receive // an empty address. - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() return e.addAndAcquirePermanentAddressLocked(addr, peb, configType, deprecated) } @@ -1669,8 +1669,8 @@ func (e *endpoint) addAndAcquirePermanentAddressLocked(addr tcpip.AddressWithPre // RemovePermanentAddress implements stack.AddressableEndpoint. func (e *endpoint) RemovePermanentAddress(addr tcpip.Address) tcpip.Error { - e.mu.RLock() - defer e.mu.RUnlock() + e.mu.Lock() + defer e.mu.Unlock() addressEndpoint := e.getAddressRLocked(addr) if addressEndpoint == nil || !addressEndpoint.GetKind().IsPermanent() { diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 4ca702121..9192d8433 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -134,7 +134,7 @@ type PacketBuffer struct { // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType - // NICID is the ID of the interface the network packet was received at. + // NICID is the ID of the last interface the network packet was handled at. NICID tcpip.NICID // RXTransportChecksumValidated indicates that transport checksum verification diff --git a/pkg/tcpip/tests/integration/iptables_test.go b/pkg/tcpip/tests/integration/iptables_test.go index 07ba2b837..f9ab7d0af 100644 --- a/pkg/tcpip/tests/integration/iptables_test.go +++ b/pkg/tcpip/tests/integration/iptables_test.go @@ -166,7 +166,7 @@ func TestIPTablesStatsForInput(t *testing.T) { // Make sure the packet is not dropped by the next rule. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -187,7 +187,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{nicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -207,7 +207,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -227,7 +227,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -250,7 +250,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -273,7 +273,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -293,7 +293,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, true, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, true, err) } }, genPacket: genPacketV6, @@ -313,7 +313,7 @@ func TestIPTablesStatsForInput(t *testing.T) { filter.Rules[ruleIdx].Matchers = []stack.Matcher{&inputIfNameMatcher{anotherNicName}} filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, false, err) + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, false, err) } }, genPacket: genPacketV4, @@ -465,7 +465,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, false /* ipv4 */); err != nil { - t.Fatalf("RelaceTable(%d, _, false): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, false): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -556,7 +556,7 @@ func TestIPTableWritePackets(t *testing.T) { } if err := s.IPTables().ReplaceTable(stack.FilterID, table, true /* ipv6 */); err != nil { - t.Fatalf("RelaceTable(%d, _, true): %s", stack.FilterID, err) + t.Fatalf("ReplaceTable(%d, _, true): %s", stack.FilterID, err) } }, genPacket: func(r *stack.Route) stack.PacketBufferList { @@ -681,6 +681,32 @@ func forwardedICMPv6EchoReplyChecker(t *testing.T, b []byte, src, dst tcpip.Addr checker.ICMPv6Type(header.ICMPv6EchoReply))) } +func boolToInt(v bool) uint64 { + if v { + return 1 + } + return 0 +} + +func setupDropFilter(hook stack.Hook, f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { + return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { + t.Helper() + + ipv6 := netProto == ipv6.ProtocolNumber + + ipt := s.IPTables() + filter := ipt.GetTable(stack.FilterID, ipv6) + ruleIdx := filter.BuiltinChains[hook] + filter.Rules[ruleIdx].Filter = f + filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} + // Make sure the packet is not dropped by the next rule. + filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} + if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { + t.Fatalf("ipt.ReplaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) + } + } +} + func TestForwardingHook(t *testing.T) { const ( nicID1 = 1 @@ -740,32 +766,6 @@ func TestForwardingHook(t *testing.T) { }, } - setupDropFilter := func(f stack.IPHeaderFilter) func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { - return func(t *testing.T, s *stack.Stack, netProto tcpip.NetworkProtocolNumber) { - t.Helper() - - ipv6 := netProto == ipv6.ProtocolNumber - - ipt := s.IPTables() - filter := ipt.GetTable(stack.FilterID, ipv6) - ruleIdx := filter.BuiltinChains[stack.Forward] - filter.Rules[ruleIdx].Filter = f - filter.Rules[ruleIdx].Target = &stack.DropTarget{NetworkProtocol: netProto} - // Make sure the packet is not dropped by the next rule. - filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{NetworkProtocol: netProto} - if err := ipt.ReplaceTable(stack.FilterID, filter, ipv6); err != nil { - t.Fatalf("ipt.RelaceTable(%d, _, %t): %s", stack.FilterID, ipv6, err) - } - } - } - - boolToInt := func(v bool) uint64 { - if v { - return 1 - } - return 0 - } - subTests := []struct { name string setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) @@ -779,59 +779,59 @@ func TestForwardingHook(t *testing.T) { { name: "Drop", - setupFilter: setupDropFilter(stack.IPHeaderFilter{}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{}), expectForward: false, }, { name: "Drop with input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name}), expectForward: false, }, { name: "Drop with output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: nic2Name}), expectForward: false, }, { name: "Drop with other input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: nic2Name}), expectForward: true, }, { name: "Drop with input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with other input and other output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: otherNICName, OutputInterface: otherNICName}), expectForward: true, }, { name: "Drop with inverted input NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{InputInterface: nic1Name, InputInterfaceInvert: true}), expectForward: true, }, { name: "Drop with inverted output NIC filtering", - setupFilter: setupDropFilter(stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), + setupFilter: setupDropFilter(stack.Forward, stack.IPHeaderFilter{OutputInterface: nic2Name, OutputInterfaceInvert: true}), expectForward: true, }, } @@ -941,3 +941,194 @@ func TestForwardingHook(t *testing.T) { }) } } + +func TestInputHookWithLocalForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Name = "nic1" + nic2Name = "nic2" + + otherNICName = "otherNIC" + ) + + tests := []struct { + name string + netProto tcpip.NetworkProtocolNumber + rx func(*channel.Endpoint) + checker func(*testing.T, []byte) + }{ + { + name: "IPv4", + netProto: ipv4.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv4EchoRequest(e, utils.RemoteIPv4Addr, utils.Ipv4Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv4(t, b, + checker.SrcAddr(utils.Ipv4Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4EchoReply))) + }, + }, + { + name: "IPv6", + netProto: ipv6.ProtocolNumber, + rx: func(e *channel.Endpoint) { + utils.RxICMPv6EchoRequest(e, utils.RemoteIPv6Addr, utils.Ipv6Addr2.AddressWithPrefix.Address, ttl) + }, + checker: func(t *testing.T, b []byte) { + checker.IPv6(t, b, + checker.SrcAddr(utils.Ipv6Addr2.AddressWithPrefix.Address), + checker.DstAddr(utils.RemoteIPv6Addr), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6EchoReply))) + }, + }, + } + + subTests := []struct { + name string + setupFilter func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) + expectDrop bool + }{ + { + name: "Accept", + setupFilter: func(*testing.T, *stack.Stack, tcpip.NetworkProtocolNumber) { /* no filter */ }, + expectDrop: false, + }, + + { + name: "Drop", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on arrival NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic1Name}), + expectDrop: true, + }, + { + name: "Drop with input NIC filtering on delivered NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: nic2Name}), + expectDrop: false, + }, + + { + name: "Drop with input NIC filtering on other NIC", + setupFilter: setupDropFilter(stack.Input, stack.IPHeaderFilter{InputInterface: otherNICName}), + expectDrop: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + }) + + subTest.setupFilter(t, s, test.netProto) + + e1 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID1, e1, stack.NICOptions{Name: nic1Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv4Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv4Addr1, err) + } + if err := s.AddProtocolAddress(nicID1, utils.Ipv6Addr1); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID1, utils.Ipv6Addr1, err) + } + + e2 := channel.New(1, header.IPv6MinimumMTU, "") + if err := s.CreateNICWithOptions(nicID2, e2, stack.NICOptions{Name: nic2Name}); err != nil { + t.Fatalf("s.CreateNICWithOptions(%d, _, _): %s", nicID2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv4Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv4Addr2, err) + } + if err := s.AddProtocolAddress(nicID2, utils.Ipv6Addr2); err != nil { + t.Fatalf("s.AddProtocolAddress(%d, %#v): %s", nicID2, utils.Ipv6Addr2, err) + } + + if err := s.SetForwardingDefaultAndAllNICs(ipv4.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv4.ProtocolNumber, err) + } + if err := s.SetForwardingDefaultAndAllNICs(ipv6.ProtocolNumber, true); err != nil { + t.Fatalf("s.SetForwardingDefaultAndAllNICs(%d, true): %s", ipv6.ProtocolNumber, err) + } + + s.SetRouteTable([]tcpip.Route{ + { + Destination: header.IPv4EmptySubnet, + NIC: nicID1, + }, + { + Destination: header.IPv6EmptySubnet, + NIC: nicID1, + }, + }) + + test.rx(e1) + + ep1, err := s.GetNetworkEndpoint(nicID1, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID1, test.netProto, err) + } + ep1Stats := ep1.Stats() + ipEP1Stats, ok := ep1Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep1Stats = %T, want = stack.IPNetworkEndpointStats", ep1Stats) + } + ip1Stats := ipEP1Stats.IPStats() + + if got := ip1Stats.PacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.PacketsReceived.Value() = %d, want = 1", got) + } + if got := ip1Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip1Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip1Stats.PacketsSent.Value(), boolToInt(!subTest.expectDrop); got != want { + t.Errorf("got ip1Stats.PacketsSent.Value() = %d, want = %d", got, want) + } + + ep2, err := s.GetNetworkEndpoint(nicID2, test.netProto) + if err != nil { + t.Fatalf("s.GetNetworkEndpoint(%d, %d): %s", nicID2, test.netProto, err) + } + ep2Stats := ep2.Stats() + ipEP2Stats, ok := ep2Stats.(stack.IPNetworkEndpointStats) + if !ok { + t.Fatalf("got ep2Stats = %T, want = stack.IPNetworkEndpointStats", ep2Stats) + } + ip2Stats := ipEP2Stats.IPStats() + if got := ip2Stats.PacketsReceived.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsReceived.Value() = %d, want = 0", got) + } + if got := ip2Stats.ValidPacketsReceived.Value(); got != 1 { + t.Errorf("got ip2Stats.ValidPacketsReceived.Value() = %d, want = 1", got) + } + if got, want := ip2Stats.IPTablesInputDropped.Value(), boolToInt(subTest.expectDrop); got != want { + t.Errorf("got ip2Stats.IPTablesInputDropped.Value() = %d, want = %d", got, want) + } + if got := ip2Stats.PacketsSent.Value(); got != 0 { + t.Errorf("got ip2Stats.PacketsSent.Value() = %d, want = 0", got) + } + + if p, ok := e1.Read(); ok == subTest.expectDrop { + t.Errorf("got e1.Read() = (%#v, %t), want = (_, %t)", p, ok, !subTest.expectDrop) + } else if !subTest.expectDrop { + test.checker(t, stack.PayloadSince(p.Pkt.NetworkHeader())) + } + if p, ok := e2.Read(); ok { + t.Errorf("got e1.Read() = (%#v, true), want = (_, false)", p) + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index 2c65b737d..d807b13b7 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -560,6 +560,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err } switch { + case s.flags.Contains(header.TCPFlagRst): + e.stack.Stats().DroppedPackets.Increment() + return nil + case s.flags == header.TCPFlagSyn: if e.acceptQueueIsFull() { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() @@ -611,7 +615,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() return nil - case (s.flags & header.TCPFlagAck) != 0: + case s.flags.Contains(header.TCPFlagAck): if e.acceptQueueIsFull() { // Silently drop the ack as the application can't accept // the connection at this point. The ack will be @@ -736,6 +740,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err mss: rcvdSynOptions.MSS, }) + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && n.enqueueSegment(s) { + s.incRef() + n.newSegmentWaker.Assert() + } + // Do the delivery in a separate goroutine so // that we don't block the listen loop in case // the application is slow to accept or stops @@ -753,6 +764,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) tcpip.Err return nil default: + e.stack.Stats().DroppedPackets.Increment() return nil } } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 570e5081c..2137ebc25 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -406,11 +406,11 @@ func (h *handshake) synRcvdState(s *segment) tcpip.Error { h.ep.transitionToStateEstablishedLocked(h) - // If the segment has data then requeue it for the receiver - // to process it again once main loop is started. - if s.data.Size() > 0 { + // Requeue the segment if the ACK completing the handshake has more info + // to be procesed by the newly established endpoint. + if (s.flags.Contains(header.TCPFlagFin) || s.data.Size() > 0) && h.ep.enqueueSegment(s) { s.incRef() - h.ep.enqueueSegment(s) + h.ep.newSegmentWaker.Assert() } return nil } @@ -1474,11 +1474,19 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ return &tcpip.ErrConnectionReset{} } - if n¬ifyClose != 0 && closeTimer == nil { - if e.EndpointState() == StateFinWait2 && e.closed { + if n¬ifyClose != 0 && e.closed { + switch e.EndpointState() { + case StateEstablished: + // Perform full shutdown if the endpoint is still + // established. This can occur when notifyClose + // was asserted just before becoming established. + e.shutdownLocked(tcpip.ShutdownWrite | tcpip.ShutdownRead) + case StateFinWait2: // The socket has been closed and we are in FIN_WAIT2 // so start the FIN_WAIT2 timer. - closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + if closeTimer == nil { + closeTimer = e.stack.Clock().AfterFunc(e.tcpLingerTimeout, closeWaker.Assert) + } } } diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index e7ede7662..9bbe9bc3e 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -6238,6 +6238,54 @@ func TestPassiveFailedConnectionAttemptIncrement(t *testing.T) { } } +func TestListenDropIncrement(t *testing.T) { + c := context.New(t, defaultMTU) + defer c.Cleanup() + + stats := c.Stack().Stats() + c.Create(-1 /*epRcvBuf*/) + + if err := c.EP.Bind(tcpip.FullAddress{Addr: context.StackAddr, Port: context.StackPort}); err != nil { + t.Fatalf("Bind failed: %s", err) + } + if err := c.EP.Listen(1 /*backlog*/); err != nil { + t.Fatalf("Listen failed: %s", err) + } + + initialDropped := stats.DroppedPackets.Value() + + // Send RST, FIN segments, that are expected to be dropped by the listener. + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagRst, + }) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagFin, + }) + + // To ensure that the RST, FIN sent earlier are indeed received and ignored + // by the listener, send a SYN and wait for the SYN to be ACKd. + irs := seqnum.Value(context.TestInitialSequenceNumber) + c.SendPacket(nil, &context.Headers{ + SrcPort: context.TestPort, + DstPort: context.StackPort, + Flags: header.TCPFlagSyn, + SeqNum: irs, + }) + checker.IPv4(t, c.GetPacket(), checker.TCP(checker.SrcPort(context.StackPort), + checker.DstPort(context.TestPort), + checker.TCPFlags(header.TCPFlagAck|header.TCPFlagSyn), + checker.TCPAckNum(uint32(irs)+1), + )) + + if got, want := stats.DroppedPackets.Value(), initialDropped+2; got != want { + t.Fatalf("got stats.DroppedPackets.Value() = %d, want = %d", got, want) + } +} + func TestEndpointBindListenAcceptState(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() |