diff options
Diffstat (limited to 'pkg/sentry/seccheck')
-rw-r--r-- | pkg/sentry/seccheck/BUILD | 54 | ||||
-rw-r--r-- | pkg/sentry/seccheck/seccheck_fieldenum.go | 134 | ||||
-rw-r--r-- | pkg/sentry/seccheck/seccheck_state_autogen.go | 3 | ||||
-rw-r--r-- | pkg/sentry/seccheck/seccheck_test.go | 157 | ||||
-rw-r--r-- | pkg/sentry/seccheck/seccheck_unsafe_state_autogen.go | 3 | ||||
-rw-r--r-- | pkg/sentry/seccheck/seqatomic_checkerslice_unsafe.go | 38 |
6 files changed, 178 insertions, 211 deletions
diff --git a/pkg/sentry/seccheck/BUILD b/pkg/sentry/seccheck/BUILD deleted file mode 100644 index 943fa180d..000000000 --- a/pkg/sentry/seccheck/BUILD +++ /dev/null @@ -1,54 +0,0 @@ -load("//tools:defs.bzl", "go_library", "go_test") -load("//tools/go_fieldenum:defs.bzl", "go_fieldenum") -load("//tools/go_generics:defs.bzl", "go_template_instance") - -licenses(["notice"]) - -go_fieldenum( - name = "seccheck_fieldenum", - srcs = [ - "clone.go", - "task.go", - ], - out = "seccheck_fieldenum.go", - package = "seccheck", -) - -go_template_instance( - name = "seqatomic_checkerslice", - out = "seqatomic_checkerslice_unsafe.go", - package = "seccheck", - suffix = "CheckerSlice", - template = "//pkg/sync/seqatomic:generic_seqatomic", - types = { - "Value": "[]Checker", - }, -) - -go_library( - name = "seccheck", - srcs = [ - "clone.go", - "seccheck.go", - "seccheck_fieldenum.go", - "seqatomic_checkerslice_unsafe.go", - "task.go", - ], - visibility = ["//:sandbox"], - deps = [ - "//pkg/abi/linux", - "//pkg/context", - "//pkg/gohacks", - "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/time", - "//pkg/sync", - ], -) - -go_test( - name = "seccheck_test", - size = "small", - srcs = ["seccheck_test.go"], - library = ":seccheck", - deps = ["//pkg/context"], -) diff --git a/pkg/sentry/seccheck/seccheck_fieldenum.go b/pkg/sentry/seccheck/seccheck_fieldenum.go new file mode 100644 index 000000000..b193b2973 --- /dev/null +++ b/pkg/sentry/seccheck/seccheck_fieldenum.go @@ -0,0 +1,134 @@ +// Generated by go_fieldenum. + +package seccheck + +import "sync/atomic" + +// A CloneField represents a field in CloneInfo. +type CloneField uint + +// CloneFieldX represents CloneInfo field X. +const ( + CloneFieldCredentials CloneField = iota + CloneFieldArgs +) + +// CloneFields represents a set of fields in CloneInfo in a literal-friendly form. +// The zero value of CloneFields represents an empty set. +type CloneFields struct { + Invoker TaskFields + Credentials bool + Args bool + Created TaskFields +} + +// CloneFieldSet represents a set of fields in CloneInfo in a compact form. +// The zero value of CloneFieldSet represents an empty set. +type CloneFieldSet struct { + Invoker TaskFieldSet + Created TaskFieldSet + fields [1]uint32 +} + +// Contains returns true if f is present in the CloneFieldSet. +func (fs CloneFieldSet) Contains(f CloneField) bool { + return fs.fields[0] & (uint32(1) << uint(f)) != 0 +} + +// Add adds f to the CloneFieldSet. +func (fs *CloneFieldSet) Add(f CloneField) { + fs.fields[0] |= uint32(1) << uint(f) +} + +// Remove removes f from the CloneFieldSet. +func (fs *CloneFieldSet) Remove(f CloneField) { + fs.fields[0] &^= uint32(1) << uint(f) +} + +// Load returns a copy of the CloneFieldSet. +// Load is safe to call concurrently with AddFieldsLoadable, but not Add or Remove. +func (fs *CloneFieldSet) Load() (copied CloneFieldSet) { + copied.Invoker = fs.Invoker.Load() + copied.Created = fs.Created.Load() + copied.fields[0] = atomic.LoadUint32(&fs.fields[0]) + return +} + +// AddFieldsLoadable adds the given fields to the CloneFieldSet. +// AddFieldsLoadable is safe to call concurrently with Load, but not other methods (including other calls to AddFieldsLoadable). +func (fs *CloneFieldSet) AddFieldsLoadable(fields CloneFields) { + fs.Invoker.AddFieldsLoadable(fields.Invoker) + fs.Created.AddFieldsLoadable(fields.Created) + if fields.Credentials { + atomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(CloneFieldCredentials))) + } + if fields.Args { + atomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(CloneFieldArgs))) + } +} + +// A TaskField represents a field in TaskInfo. +type TaskField uint + +// TaskFieldX represents TaskInfo field X. +const ( + TaskFieldThreadID TaskField = iota + TaskFieldThreadStartTime + TaskFieldThreadGroupID + TaskFieldThreadGroupStartTime +) + +// TaskFields represents a set of fields in TaskInfo in a literal-friendly form. +// The zero value of TaskFields represents an empty set. +type TaskFields struct { + ThreadID bool + ThreadStartTime bool + ThreadGroupID bool + ThreadGroupStartTime bool +} + +// TaskFieldSet represents a set of fields in TaskInfo in a compact form. +// The zero value of TaskFieldSet represents an empty set. +type TaskFieldSet struct { + fields [1]uint32 +} + +// Contains returns true if f is present in the TaskFieldSet. +func (fs TaskFieldSet) Contains(f TaskField) bool { + return fs.fields[0] & (uint32(1) << uint(f)) != 0 +} + +// Add adds f to the TaskFieldSet. +func (fs *TaskFieldSet) Add(f TaskField) { + fs.fields[0] |= uint32(1) << uint(f) +} + +// Remove removes f from the TaskFieldSet. +func (fs *TaskFieldSet) Remove(f TaskField) { + fs.fields[0] &^= uint32(1) << uint(f) +} + +// Load returns a copy of the TaskFieldSet. +// Load is safe to call concurrently with AddFieldsLoadable, but not Add or Remove. +func (fs *TaskFieldSet) Load() (copied TaskFieldSet) { + copied.fields[0] = atomic.LoadUint32(&fs.fields[0]) + return +} + +// AddFieldsLoadable adds the given fields to the TaskFieldSet. +// AddFieldsLoadable is safe to call concurrently with Load, but not other methods (including other calls to AddFieldsLoadable). +func (fs *TaskFieldSet) AddFieldsLoadable(fields TaskFields) { + if fields.ThreadID { + atomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(TaskFieldThreadID))) + } + if fields.ThreadStartTime { + atomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(TaskFieldThreadStartTime))) + } + if fields.ThreadGroupID { + atomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(TaskFieldThreadGroupID))) + } + if fields.ThreadGroupStartTime { + atomic.StoreUint32(&fs.fields[0], fs.fields[0] | (uint32(1) << uint(TaskFieldThreadGroupStartTime))) + } +} + diff --git a/pkg/sentry/seccheck/seccheck_state_autogen.go b/pkg/sentry/seccheck/seccheck_state_autogen.go new file mode 100644 index 000000000..2fa2e9787 --- /dev/null +++ b/pkg/sentry/seccheck/seccheck_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seccheck diff --git a/pkg/sentry/seccheck/seccheck_test.go b/pkg/sentry/seccheck/seccheck_test.go deleted file mode 100644 index 687810d18..000000000 --- a/pkg/sentry/seccheck/seccheck_test.go +++ /dev/null @@ -1,157 +0,0 @@ -// Copyright 2021 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package seccheck - -import ( - "errors" - "testing" - - "gvisor.dev/gvisor/pkg/context" -) - -type testChecker struct { - CheckerDefaults - - onClone func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error -} - -// Clone implements Checker.Clone. -func (c *testChecker) Clone(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { - if c.onClone == nil { - return nil - } - return c.onClone(ctx, mask, info) -} - -func TestNoChecker(t *testing.T) { - var s state - if s.Enabled(PointClone) { - t.Errorf("Enabled(PointClone): got true, wanted false") - } -} - -func TestCheckerNotRegisteredForPoint(t *testing.T) { - var s state - s.AppendChecker(&testChecker{}, &CheckerReq{}) - if s.Enabled(PointClone) { - t.Errorf("Enabled(PointClone): got true, wanted false") - } -} - -func TestCheckerRegistered(t *testing.T) { - var s state - checkerCalled := false - s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { - checkerCalled = true - return nil - }}, &CheckerReq{ - Points: []Point{PointClone}, - Clone: CloneFields{ - Credentials: true, - }, - }) - - if !s.Enabled(PointClone) { - t.Errorf("Enabled(PointClone): got false, wanted true") - } - if !s.CloneReq().Contains(CloneFieldCredentials) { - t.Errorf("CloneReq().Contains(CloneFieldCredentials): got false, wanted true") - } - if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil { - t.Errorf("Clone(): got %v, wanted nil", err) - } - if !checkerCalled { - t.Errorf("Clone() did not call Checker.Clone()") - } -} - -func TestMultipleCheckersRegistered(t *testing.T) { - var s state - checkersCalled := [2]bool{} - s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { - checkersCalled[0] = true - return nil - }}, &CheckerReq{ - Points: []Point{PointClone}, - Clone: CloneFields{ - Args: true, - }, - }) - s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { - checkersCalled[1] = true - return nil - }}, &CheckerReq{ - Points: []Point{PointClone}, - Clone: CloneFields{ - Created: TaskFields{ - ThreadID: true, - }, - }, - }) - - if !s.Enabled(PointClone) { - t.Errorf("Enabled(PointClone): got false, wanted true") - } - // CloneReq() should return the union of requested fields from all calls to - // AppendChecker. - req := s.CloneReq() - if !req.Contains(CloneFieldArgs) { - t.Errorf("req.Contains(CloneFieldArgs): got false, wanted true") - } - if !req.Created.Contains(TaskFieldThreadID) { - t.Errorf("req.Created.Contains(TaskFieldThreadID): got false, wanted true") - } - if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != nil { - t.Errorf("Clone(): got %v, wanted nil", err) - } - for i := range checkersCalled { - if !checkersCalled[i] { - t.Errorf("Clone() did not call Checker.Clone() index %d", i) - } - } -} - -func TestCheckpointReturnsFirstCheckerError(t *testing.T) { - errFirstChecker := errors.New("first Checker error") - errSecondChecker := errors.New("second Checker error") - - var s state - checkersCalled := [2]bool{} - s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { - checkersCalled[0] = true - return errFirstChecker - }}, &CheckerReq{ - Points: []Point{PointClone}, - }) - s.AppendChecker(&testChecker{onClone: func(ctx context.Context, mask CloneFieldSet, info CloneInfo) error { - checkersCalled[1] = true - return errSecondChecker - }}, &CheckerReq{ - Points: []Point{PointClone}, - }) - - if !s.Enabled(PointClone) { - t.Errorf("Enabled(PointClone): got false, wanted true") - } - if err := s.Clone(context.Background(), CloneFieldSet{}, &CloneInfo{}); err != errFirstChecker { - t.Errorf("Clone(): got %v, wanted %v", err, errFirstChecker) - } - if !checkersCalled[0] { - t.Errorf("Clone() did not call first Checker") - } - if checkersCalled[1] { - t.Errorf("Clone() called second Checker") - } -} diff --git a/pkg/sentry/seccheck/seccheck_unsafe_state_autogen.go b/pkg/sentry/seccheck/seccheck_unsafe_state_autogen.go new file mode 100644 index 000000000..2fa2e9787 --- /dev/null +++ b/pkg/sentry/seccheck/seccheck_unsafe_state_autogen.go @@ -0,0 +1,3 @@ +// automatically generated by stateify. + +package seccheck diff --git a/pkg/sentry/seccheck/seqatomic_checkerslice_unsafe.go b/pkg/sentry/seccheck/seqatomic_checkerslice_unsafe.go new file mode 100644 index 000000000..05a6c6eee --- /dev/null +++ b/pkg/sentry/seccheck/seqatomic_checkerslice_unsafe.go @@ -0,0 +1,38 @@ +package seccheck + +import ( + "unsafe" + + "gvisor.dev/gvisor/pkg/gohacks" + "gvisor.dev/gvisor/pkg/sync" +) + +// SeqAtomicLoad returns a copy of *ptr, ensuring that the read does not race +// with any writer critical sections in seq. +// +//go:nosplit +func SeqAtomicLoadCheckerSlice(seq *sync.SeqCount, ptr *[]Checker) []Checker { + for { + if val, ok := SeqAtomicTryLoadCheckerSlice(seq, seq.BeginRead(), ptr); ok { + return val + } + } +} + +// SeqAtomicTryLoad returns a copy of *ptr while in a reader critical section +// in seq initiated by a call to seq.BeginRead() that returned epoch. If the +// read would race with a writer critical section, SeqAtomicTryLoad returns +// (unspecified, false). +// +//go:nosplit +func SeqAtomicTryLoadCheckerSlice(seq *sync.SeqCount, epoch sync.SeqCountEpoch, ptr *[]Checker) (val []Checker, ok bool) { + if sync.RaceEnabled { + + gohacks.Memmove(unsafe.Pointer(&val), unsafe.Pointer(ptr), unsafe.Sizeof(val)) + } else { + + val = *ptr + } + ok = seq.ReadOk(epoch) + return +} |