summaryrefslogtreecommitdiffhomepage
diff options
context:
space:
mode:
-rwxr-xr-xdebian/postinst.sh2
-rw-r--r--images/default/Dockerfile2
-rw-r--r--nogo.yaml11
-rw-r--r--pkg/bitmap/BUILD16
-rw-r--r--pkg/bitmap/bitmap.go234
-rw-r--r--pkg/bitmap/bitmap_test.go306
-rw-r--r--pkg/gohacks/gohacks_unsafe.go8
-rw-r--r--pkg/procid/procid_amd64.s4
-rw-r--r--pkg/procid/procid_arm64.s4
-rw-r--r--pkg/sentry/control/BUILD4
-rw-r--r--pkg/sentry/control/fs.go93
-rw-r--r--pkg/sentry/control/lifecycle.go36
-rw-r--r--pkg/sentry/fsimpl/cgroupfs/cgroupfs.go5
-rw-r--r--pkg/sentry/fsimpl/gofer/gofer.go28
-rw-r--r--pkg/sentry/fsimpl/proc/task.go2
-rw-r--r--pkg/sentry/fsimpl/proc/task_files.go111
-rw-r--r--pkg/sentry/inet/inet.go3
-rw-r--r--pkg/sentry/inet/test_stack.go50
-rw-r--r--pkg/sentry/kernel/BUILD1
-rw-r--r--pkg/sentry/kernel/fd_table.go196
-rw-r--r--pkg/sentry/kernel/fd_table_unsafe.go11
-rw-r--r--pkg/sentry/kernel/msgqueue/msgqueue.go278
-rw-r--r--pkg/sentry/platform/kvm/bluepill_fault.go2
-rw-r--r--pkg/sentry/platform/kvm/bluepill_unsafe.go10
-rw-r--r--pkg/sentry/platform/kvm/machine.go13
-rw-r--r--pkg/sentry/platform/kvm/machine_unsafe.go8
-rw-r--r--pkg/sentry/platform/ptrace/subprocess_unsafe.go8
-rw-r--r--pkg/sentry/socket/hostinet/stack.go29
-rw-r--r--pkg/sentry/socket/netlink/route/protocol.go43
-rw-r--r--pkg/sentry/socket/netstack/stack.go6
-rw-r--r--pkg/sentry/syscalls/linux/BUILD1
-rw-r--r--pkg/sentry/syscalls/linux/linux64.go8
-rw-r--r--pkg/sentry/syscalls/linux/sys_msgqueue.go85
-rw-r--r--pkg/sentry/syscalls/linux/sys_prctl.go16
-rw-r--r--pkg/sync/goyield_unsafe.go8
-rw-r--r--pkg/sync/runtime_unsafe.go13
-rw-r--r--pkg/tcpip/link/fdbased/endpoint.go32
-rw-r--r--pkg/tcpip/link/fdbased/mmap.go21
-rw-r--r--pkg/tcpip/link/fdbased/mmap_unsafe.go9
-rw-r--r--pkg/tcpip/link/fdbased/packet_dispatchers.go53
-rw-r--r--pkg/tcpip/link/qdisc/fifo/endpoint.go7
-rw-r--r--pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go7
-rw-r--r--pkg/tcpip/link/rawfile/rawfile_unsafe.go55
-rw-r--r--pkg/tcpip/link/sniffer/pcap.go56
-rw-r--r--pkg/tcpip/link/sniffer/sniffer.go33
-rw-r--r--pkg/tcpip/network/ipv4/BUILD1
-rw-r--r--pkg/tcpip/network/ipv4/ipv4_test.go3
-rw-r--r--pkg/tcpip/stack/stack.go3
-rw-r--r--pkg/tcpip/stack/stack_test.go77
-rw-r--r--runsc/boot/controller.go39
-rw-r--r--runsc/cmd/chroot.go25
-rw-r--r--runsc/cmd/debug.go8
-rw-r--r--runsc/cmd/gofer.go14
-rw-r--r--runsc/container/container.go6
-rw-r--r--runsc/container/container_test.go82
-rw-r--r--runsc/sandbox/sandbox.go22
-rw-r--r--test/e2e/integration_test.go46
-rw-r--r--test/perf/BUILD28
-rw-r--r--test/perf/linux/BUILD73
-rw-r--r--test/perf/linux/dup_benchmark.cc55
-rw-r--r--test/perf/linux/verity_open_benchmark.cc2
-rw-r--r--test/perf/linux/verity_open_read_close_benchmark.cc75
-rw-r--r--test/perf/linux/verity_randread_benchmark.cc108
-rw-r--r--test/perf/linux/verity_read_benchmark.cc69
-rw-r--r--test/root/chroot_test.go10
-rw-r--r--test/syscalls/linux/BUILD4
-rw-r--r--test/syscalls/linux/dup.cc41
-rw-r--r--test/syscalls/linux/lseek.cc3
-rw-r--r--test/syscalls/linux/memfd.cc5
-rw-r--r--test/syscalls/linux/msgqueue.cc572
-rw-r--r--test/syscalls/linux/prctl.cc6
-rw-r--r--test/syscalls/linux/proc.cc42
-rw-r--r--test/syscalls/linux/socket_netlink_route.cc116
-rw-r--r--test/syscalls/linux/stat.cc2
-rw-r--r--test/syscalls/linux/statfs.cc2
-rw-r--r--tools/checklinkname/BUILD16
-rw-r--r--tools/checklinkname/README.md54
-rw-r--r--tools/checklinkname/check_linkname.go229
-rw-r--r--tools/checklinkname/known.go110
-rw-r--r--tools/checklinkname/test/BUILD9
-rw-r--r--tools/checklinkname/test/test_unsafe.go34
-rw-r--r--tools/constraintutil/BUILD18
-rw-r--r--tools/constraintutil/constraintutil.go169
-rw-r--r--tools/constraintutil/constraintutil_test.go138
-rw-r--r--tools/go_generics/go_merge/BUILD2
-rw-r--r--tools/go_generics/go_merge/main.go13
-rw-r--r--tools/go_marshal/gomarshal/BUILD2
-rw-r--r--tools/go_marshal/gomarshal/generator.go33
-rw-r--r--tools/go_stateify/BUILD2
-rw-r--r--tools/go_stateify/main.go11
-rw-r--r--tools/nogo/BUILD1
-rw-r--r--tools/nogo/analyzers.go2
-rw-r--r--tools/nogo/filter/main.go2
-rw-r--r--tools/tags/BUILD11
-rw-r--r--tools/tags/tags.go89
95 files changed, 3903 insertions, 509 deletions
diff --git a/debian/postinst.sh b/debian/postinst.sh
index 6a326f823..b387b9f22 100755
--- a/debian/postinst.sh
+++ b/debian/postinst.sh
@@ -22,7 +22,7 @@ fi
if [ -f /etc/docker/daemon.json ]; then
runsc install
if systemctl is-active -q docker; then
- systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2
+ systemctl reload docker || echo "unable to reload docker; you must do so manually." >&2
fi
fi
diff --git a/images/default/Dockerfile b/images/default/Dockerfile
index 5f652f2c3..4384d6271 100644
--- a/images/default/Dockerfile
+++ b/images/default/Dockerfile
@@ -15,7 +15,7 @@ RUN add-apt-repository \
"deb https://download.docker.com/linux/ubuntu \
$(lsb_release -cs) \
stable"
-RUN apt-get install docker-ce-cli
+RUN apt-get -y install docker-ce-cli
# Install gcloud.
RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \
diff --git a/nogo.yaml b/nogo.yaml
index d7a3ad956..e7d1741eb 100644
--- a/nogo.yaml
+++ b/nogo.yaml
@@ -150,6 +150,17 @@ analyzers:
external: # Enabled.
checkescape:
external: # Enabled.
+ checklinkname:
+ external: # Enabled.
+ suppress:
+ # We don't care to check every single linkname in the Go standard
+ # library. Suppress findings about stdlib linkname targets we haven't
+ # described in checklinkname.
+ #
+ # Note that we _do_ want to check the signature of the known linkname
+ # targets in the standard library, so we still need to run
+ # checklinkname on stdlib generally.
+ - "linkname to unknown symbol"
SA4016:
internal:
exclude:
diff --git a/pkg/bitmap/BUILD b/pkg/bitmap/BUILD
new file mode 100644
index 000000000..0f1e7006d
--- /dev/null
+++ b/pkg/bitmap/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "bitmap",
+ srcs = ["bitmap.go"],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "bitmap_test",
+ size = "small",
+ srcs = ["bitmap_test.go"],
+ library = ":bitmap",
+)
diff --git a/pkg/bitmap/bitmap.go b/pkg/bitmap/bitmap.go
new file mode 100644
index 000000000..12d2fc2b8
--- /dev/null
+++ b/pkg/bitmap/bitmap.go
@@ -0,0 +1,234 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bitmap provides the implementation of bitmap.
+package bitmap
+
+import (
+ "math"
+ "math/bits"
+)
+
+// Bitmap implements an efficient bitmap.
+//
+// +stateify savable
+type Bitmap struct {
+ // numOnes is the number of ones in the bitmap.
+ numOnes uint32
+
+ // bitBlock holds the bits. The type of bitBlock is uint64 which means
+ // each number in bitBlock contains 64 entries.
+ bitBlock []uint64
+}
+
+// New create a new empty Bitmap.
+func New(size uint32) Bitmap {
+ b := Bitmap{}
+ bSize := (size + 63) / 64
+ b.bitBlock = make([]uint64, bSize)
+ return b
+}
+
+// IsEmpty verifies whether the Bitmap is empty.
+func (b *Bitmap) IsEmpty() bool {
+ return b.numOnes == 0
+}
+
+// Minimum return the smallest value in the Bitmap.
+func (b *Bitmap) Minimum() uint32 {
+ for i := 0; i < len(b.bitBlock); i++ {
+ if w := b.bitBlock[i]; w != 0 {
+ r := bits.TrailingZeros64(w)
+ return uint32(r + i*64)
+ }
+ }
+ return math.MaxInt32
+}
+
+// FirstZero returns the first unset bit from the range [start, ).
+func (b *Bitmap) FirstZero(start uint32) uint32 {
+ i, nbit := int(start/64), start%64
+ n := len(b.bitBlock)
+ if i >= n {
+ return math.MaxInt32
+ }
+ w := b.bitBlock[i] | ((1 << nbit) - 1)
+ for {
+ if w != ^uint64(0) {
+ r := bits.TrailingZeros64(^w)
+ return uint32(r + i*64)
+ }
+ i++
+ if i == n {
+ break
+ }
+ w = b.bitBlock[i]
+ }
+ return math.MaxInt32
+}
+
+// Maximum return the largest value in the Bitmap.
+func (b *Bitmap) Maximum() uint32 {
+ for i := len(b.bitBlock) - 1; i >= 0; i-- {
+ if w := b.bitBlock[i]; w != 0 {
+ r := bits.LeadingZeros64(w)
+ return uint32(i*64 + 63 - r)
+ }
+ }
+ return uint32(0)
+}
+
+// Add add i to the Bitmap.
+func (b *Bitmap) Add(i uint32) {
+ blockNum, mask := i/64, uint64(1)<<(i%64)
+ // if blockNum is out of range, extend b.bitBlock
+ if x, y := int(blockNum), len(b.bitBlock); x >= y {
+ b.bitBlock = append(b.bitBlock, make([]uint64, x-y+1)...)
+ }
+ oldBlock := b.bitBlock[blockNum]
+ newBlock := oldBlock | mask
+ if oldBlock != newBlock {
+ b.bitBlock[blockNum] = newBlock
+ b.numOnes++
+ }
+}
+
+// Remove i from the Bitmap.
+func (b *Bitmap) Remove(i uint32) {
+ blockNum, mask := i/64, uint64(1)<<(i%64)
+ oldBlock := b.bitBlock[blockNum]
+ newBlock := oldBlock &^ mask
+ if oldBlock != newBlock {
+ b.bitBlock[blockNum] = newBlock
+ b.numOnes--
+ }
+}
+
+// Clone the Bitmap.
+func (b *Bitmap) Clone() Bitmap {
+ bitmap := Bitmap{b.numOnes, make([]uint64, len(b.bitBlock))}
+ copy(bitmap.bitBlock, b.bitBlock[:])
+ return bitmap
+}
+
+// countOnesForBlocks count all 1 bits within b.bitBlock of begin and that of end.
+// The begin block and end block are inclusive.
+func (b *Bitmap) countOnesForBlocks(begin, end uint32) uint64 {
+ ones := uint64(0)
+ beginBlock := begin / 64
+ endBlock := end / 64
+ for i := beginBlock; i <= endBlock; i++ {
+ ones += uint64(bits.OnesCount64(b.bitBlock[i]))
+ }
+ return ones
+}
+
+// countOnesForAllBlocks count all 1 bits in b.bitBlock.
+func (b *Bitmap) countOnesForAllBlocks() uint64 {
+ ones := uint64(0)
+ for i := 0; i < len(b.bitBlock); i++ {
+ ones += uint64(bits.OnesCount64(b.bitBlock[i]))
+ }
+ return ones
+}
+
+// flipRange flip the bits within range (begin and end). begin is inclusive and end is exclusive.
+func (b *Bitmap) flipRange(begin, end uint32) {
+ end--
+ beginBlock := begin / 64
+ endBlock := end / 64
+ if beginBlock == endBlock {
+ b.bitBlock[endBlock] ^= ((^uint64(0) << uint(begin%64)) & ((uint64(1) << (uint(end)%64 + 1)) - 1))
+ } else {
+ b.bitBlock[beginBlock] ^= ^(^uint64(0) << uint(begin%64))
+ for i := beginBlock; i < endBlock; i++ {
+ b.bitBlock[i] = ^b.bitBlock[i]
+ }
+ b.bitBlock[endBlock] ^= ((uint64(1) << (uint(end)%64 + 1)) - 1)
+ }
+}
+
+// clearRange clear the bits within range (begin and end). begin is inclusive and end is exclusive.
+func (b *Bitmap) clearRange(begin, end uint32) {
+ end--
+ beginBlock := begin / 64
+ endBlock := end / 64
+ if beginBlock == endBlock {
+ b.bitBlock[beginBlock] &= (((uint64(1) << uint(begin%64)) - 1) | ^((uint64(1) << (uint(end)%64 + 1)) - 1))
+ } else {
+ b.bitBlock[beginBlock] &= ((uint64(1) << uint(begin%64)) - 1)
+ for i := beginBlock + 1; i < endBlock; i++ {
+ b.bitBlock[i] &= ^b.bitBlock[i]
+ }
+ b.bitBlock[endBlock] &= ^((uint64(1) << (uint(end)%64 + 1)) - 1)
+ }
+}
+
+// ClearRange clear bits within range (begin and end) for the Bitmap. begin is inclusive and end is exclusive.
+func (b *Bitmap) ClearRange(begin, end uint32) {
+ blockRange := end/64 - begin/64
+ // When the number of cleared blocks is larger than half of the length of b.bitBlock,
+ // counting 1s for the entire bitmap has better performance.
+ if blockRange > uint32(len(b.bitBlock)/2) {
+ b.clearRange(begin, end)
+ b.numOnes = uint32(b.countOnesForAllBlocks())
+ } else {
+ oldRangeOnes := b.countOnesForBlocks(begin, end)
+ b.clearRange(begin, end)
+ newRangeOnes := b.countOnesForBlocks(begin, end)
+ b.numOnes += uint32(newRangeOnes - oldRangeOnes)
+ }
+}
+
+// FlipRange flip bits within range (begin and end) for the Bitmap. begin is inclusive and end is exclusive.
+func (b *Bitmap) FlipRange(begin, end uint32) {
+ blockRange := end/64 - begin/64
+ // When the number of flipped blocks is larger than half of the length of b.bitBlock,
+ // counting 1s for the entire bitmap has better performance.
+ if blockRange > uint32(len(b.bitBlock)/2) {
+ b.flipRange(begin, end)
+ b.numOnes = uint32(b.countOnesForAllBlocks())
+ } else {
+ oldRangeOnes := b.countOnesForBlocks(begin, end)
+ b.flipRange(begin, end)
+ newRangeOnes := b.countOnesForBlocks(begin, end)
+ b.numOnes += uint32(newRangeOnes - oldRangeOnes)
+ }
+}
+
+// ToSlice transform the Bitmap into slice. For example, a bitmap of [0, 1, 0, 1]
+// will return the slice [1, 3].
+func (b *Bitmap) ToSlice() []uint32 {
+ bitmapSlice := make([]uint32, 0, b.numOnes)
+ // base is the start number of a bitBlock
+ base := 0
+ for i := 0; i < len(b.bitBlock); i++ {
+ bitBlock := b.bitBlock[i]
+ // Iterate through all the numbers held by this bit block.
+ for bitBlock != 0 {
+ // Extract the lowest set 1 bit.
+ j := bitBlock & -bitBlock
+ // Interpret the bit as the in32 number it represents and add it to result.
+ bitmapSlice = append(bitmapSlice, uint32((base + int(bits.OnesCount64(j-1)))))
+ bitBlock ^= j
+ }
+ base += 64
+ }
+ return bitmapSlice
+}
+
+// GetNumOnes return the the number of ones in the Bitmap.
+func (b *Bitmap) GetNumOnes() uint32 {
+ return b.numOnes
+}
diff --git a/pkg/bitmap/bitmap_test.go b/pkg/bitmap/bitmap_test.go
new file mode 100644
index 000000000..76ebd779f
--- /dev/null
+++ b/pkg/bitmap/bitmap_test.go
@@ -0,0 +1,306 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bitmap
+
+import (
+ "math"
+ "reflect"
+ "testing"
+)
+
+// generateFilledSlice generates a slice fill with numbers.
+func generateFilledSlice(min, max, length int) []uint32 {
+ if max == min {
+ return []uint32{uint32(min)}
+ }
+ if length > (max - min) {
+ return nil
+ }
+ randSlice := make([]uint32, length)
+ if length != 0 {
+ rangeNum := uint32((max - min) / length)
+ randSlice[0], randSlice[length-1] = uint32(min), uint32(max)
+ for i := 1; i < length-1; i++ {
+ randSlice[i] = randSlice[i-1] + rangeNum
+ }
+ }
+ return randSlice
+}
+
+// generateFilledBitmap generates a Bitmap filled with fillNum of numbers,
+// and returns the slice and bitmap.
+func generateFilledBitmap(min, max, fillNum int) ([]uint32, Bitmap) {
+ bitmap := New(uint32(max))
+ randSlice := generateFilledSlice(min, max, fillNum)
+ for i := 0; i < fillNum; i++ {
+ bitmap.Add(randSlice[i])
+ }
+ return randSlice, bitmap
+}
+
+func TestNewBitmap(t *testing.T) {
+ tests := []struct {
+ name string
+ size int
+ expectSize int
+ }{
+ {"length 1", 1, 1},
+ {"length 1024", 1024, 16},
+ {"length 1025", 1025, 17},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ if bitmap := New(uint32(tt.size)); len(bitmap.bitBlock) != tt.expectSize {
+ t.Errorf("New created bitmap with %v, bitBlock size: %d, wanted: %d", tt.name, len(bitmap.bitBlock), tt.expectSize)
+ }
+ })
+ }
+}
+
+func TestAdd(t *testing.T) {
+ tests := []struct {
+ name string
+ bitmapSize int
+ addNum int
+ }{
+ {"Add with null bitmap.bitBlock", 0, 10},
+ {"Add without extending bitBlock", 64, 10},
+ {"Add without extending bitblock with margin number", 63, 64},
+ {"Add with extended one block", 1024, 1025},
+ {"Add with extended more then one block", 1024, 2048},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ bitmap := New(uint32(tt.bitmapSize))
+ bitmap.Add(uint32(tt.addNum))
+ bitmapSlice := bitmap.ToSlice()
+ if bitmapSlice[0] != uint32(tt.addNum) {
+ t.Errorf("%v, get number: %d, wanted: %d.", tt.name, bitmapSlice[0], tt.addNum)
+ }
+ })
+ }
+}
+
+func TestRemove(t *testing.T) {
+ bitmap := New(uint32(1024))
+ firstSlice := generateFilledSlice(0, 511, 50)
+ secondSlice := generateFilledSlice(512, 1024, 50)
+ for i := 0; i < 50; i++ {
+ bitmap.Add(firstSlice[i])
+ bitmap.Add(secondSlice[i])
+ }
+
+ for i := 0; i < 50; i++ {
+ bitmap.Remove(firstSlice[i])
+ }
+ bitmapSlice := bitmap.ToSlice()
+ if !reflect.DeepEqual(bitmapSlice, secondSlice) {
+ t.Errorf("After Remove() firstSlice, remained slice: %v, wanted: %v", bitmapSlice, secondSlice)
+ }
+
+ for i := 0; i < 50; i++ {
+ bitmap.Remove(secondSlice[i])
+ }
+ bitmapSlice = bitmap.ToSlice()
+ emptySlice := make([]uint32, 0)
+ if !reflect.DeepEqual(bitmapSlice, emptySlice) {
+ t.Errorf("After Remove secondSlice, remained slice: %v, wanted: %v", bitmapSlice, emptySlice)
+ }
+
+}
+
+// Verify flip bits within one bitBlock, one bit and bits cross multi bitBlocks.
+func TestFlipRange(t *testing.T) {
+ tests := []struct {
+ name string
+ flipRangeMin int
+ flipRangeMax int
+ filledSliceLen int
+ }{
+ {"Flip one number in bitmap", 77, 77, 1},
+ {"Flip numbers within one bitBlocks", 5, 60, 20},
+ {"Flip numbers that cross multi bitBlocks", 20, 1000, 300},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ fillSlice, bitmap := generateFilledBitmap(tt.flipRangeMin, tt.flipRangeMax, tt.filledSliceLen)
+ flipFillSlice := make([]uint32, 0)
+ for i, j := tt.flipRangeMin, 0; i <= tt.flipRangeMax; i++ {
+ if uint32(i) != fillSlice[j] {
+ flipFillSlice = append(flipFillSlice, uint32(i))
+ } else {
+ j++
+ }
+ }
+
+ bitmap.FlipRange(uint32(tt.flipRangeMin), uint32(tt.flipRangeMax+1))
+ flipBitmapSlice := bitmap.ToSlice()
+ if !reflect.DeepEqual(flipFillSlice, flipBitmapSlice) {
+ t.Errorf("%v, flipped slice: %v, wanted: %v", tt.name, flipBitmapSlice, flipFillSlice)
+ }
+ })
+ }
+}
+
+// Verify clear bits within one bitBlock, one bit and bits cross multi bitBlocks.
+func TestClearRange(t *testing.T) {
+ tests := []struct {
+ name string
+ clearRangeMin int
+ clearRangeMax int
+ bitmapSize int
+ }{
+ {"ClearRange clear one number", 5, 5, 64},
+ {"ClearRange clear numbers within one bitBlock", 4, 61, 64},
+ {"ClearRange clear numbers cross multi bitBlocks", 20, 254, 512},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ bitmap := New(uint32(tt.bitmapSize))
+ bitmap.FlipRange(uint32(0), uint32(tt.bitmapSize))
+ bitmap.ClearRange(uint32(tt.clearRangeMin), uint32(tt.clearRangeMax+1))
+ clearedBitmapSlice := bitmap.ToSlice()
+ clearedSlice := make([]uint32, 0)
+ for i := 0; i < tt.bitmapSize; i++ {
+ if i < tt.clearRangeMin || i > tt.clearRangeMax {
+ clearedSlice = append(clearedSlice, uint32(i))
+ }
+ }
+ if !reflect.DeepEqual(clearedSlice, clearedBitmapSlice) {
+ t.Errorf("%v, cleared slice: %v, wanted: %v", tt.name, clearedBitmapSlice, clearedSlice)
+ }
+ })
+
+ }
+}
+
+func TestMinimum(t *testing.T) {
+ randSlice, bitmap := generateFilledBitmap(0, 1024, 200)
+ min := bitmap.Minimum()
+ if min != randSlice[0] {
+ t.Errorf("Minimum() returns: %v, wanted: %v", min, randSlice[0])
+ }
+
+ bitmap.ClearRange(uint32(0), uint32(200))
+ min = bitmap.Minimum()
+ bitmapSlice := bitmap.ToSlice()
+ if min != bitmapSlice[0] {
+ t.Errorf("After ClearRange, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0])
+ }
+
+ bitmap.FlipRange(uint32(2), uint32(11))
+ min = bitmap.Minimum()
+ bitmapSlice = bitmap.ToSlice()
+ if min != bitmapSlice[0] {
+ t.Errorf("After Flip, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0])
+ }
+}
+
+func TestMaximum(t *testing.T) {
+ randSlice, bitmap := generateFilledBitmap(0, 1024, 200)
+ max := bitmap.Maximum()
+ if max != randSlice[len(randSlice)-1] {
+ t.Errorf("Maximum() returns: %v, wanted: %v", max, randSlice[len(randSlice)-1])
+ }
+
+ bitmap.ClearRange(uint32(1000), uint32(1025))
+ max = bitmap.Maximum()
+ bitmapSlice := bitmap.ToSlice()
+ if max != bitmapSlice[len(bitmapSlice)-1] {
+ t.Errorf("After ClearRange, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1])
+ }
+
+ bitmap.FlipRange(uint32(1001), uint32(1021))
+ max = bitmap.Maximum()
+ bitmapSlice = bitmap.ToSlice()
+ if max != bitmapSlice[len(bitmapSlice)-1] {
+ t.Errorf("After Flip, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1])
+ }
+}
+
+func TestBitmapNumOnes(t *testing.T) {
+ randSlice, bitmap := generateFilledBitmap(0, 1024, 200)
+ bitmapOnes := bitmap.GetNumOnes()
+ if bitmapOnes != uint32(200) {
+ t.Errorf("GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200)
+ }
+ // Remove 10 numbers.
+ for i := 5; i < 15; i++ {
+ bitmap.Remove(randSlice[i])
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(190) {
+ t.Errorf("After Remove 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190)
+ }
+ // Remove the 10 number again, the length supposed not change.
+ for i := 5; i < 15; i++ {
+ bitmap.Remove(randSlice[i])
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(190) {
+ t.Errorf("After Remove the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190)
+ }
+
+ // Add 10 number.
+ for i := 1080; i < 1090; i++ {
+ bitmap.Add(uint32(i))
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(200) {
+ t.Errorf("After Add 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200)
+ }
+
+ // Add the 10 number again, the length supposed not change.
+ for i := 1080; i < 1090; i++ {
+ bitmap.Add(uint32(i))
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(200) {
+ t.Errorf("After Add the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200)
+ }
+
+ // Flip 10 bits from 0 to 1.
+ bitmap.FlipRange(uint32(1025), uint32(1035))
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(210) {
+ t.Errorf("After Flip, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 210)
+ }
+
+ // ClearRange numbers range from [0, 1025).
+ bitmap.ClearRange(uint32(0), uint32(1025))
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(20) {
+ t.Errorf("After ClearRange, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 20)
+ }
+}
+
+func TestFirstZero(t *testing.T) {
+ bitmap := New(uint32(1000))
+ bitmap.FlipRange(200, 400)
+ for i, j := range map[uint32]uint32{0: 0, 201: 400, 200: 400, 199: 199, 400: 400, 10000: math.MaxInt32} {
+ v := bitmap.FirstZero(i)
+ if v != j {
+ t.Errorf("Minimum() returns: %v, wanted: %v", v, j)
+ }
+ }
+}
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
index 09fc14787..bd8ceba19 100644
--- a/pkg/gohacks/gohacks_unsafe.go
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -15,7 +15,13 @@
//go:build go1.13 && !go1.18
// +build go1.13,!go1.18
-// Check type signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
+
+// Check type signatures and Noescape when updating Go version.
+//
+// TODO(b/165820485): add these checks to checklinkname.
// Package gohacks contains utilities for subverting the Go compiler.
package gohacks
diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s
index b5bbfff90..74a8de42c 100644
--- a/pkg/procid/procid_amd64.s
+++ b/pkg/procid/procid_amd64.s
@@ -15,6 +15,10 @@
//go:build amd64 && go1.8 && !go1.18 && go1.1
// +build amd64,go1.8,!go1.18,go1.1
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
+
#include "textflag.h"
TEXT ·Current(SB),NOSPLIT,$0-8
diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s
index 772d96289..48182c4a9 100644
--- a/pkg/procid/procid_arm64.s
+++ b/pkg/procid/procid_arm64.s
@@ -15,6 +15,10 @@
//go:build arm64 && go1.8 && !go1.18 && go1.1
// +build arm64,go1.8,!go1.18,go1.1
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
+
#include "textflag.h"
TEXT ·Current(SB),NOSPLIT,$0-8
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index deaf5fa23..7ee237c9f 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -6,6 +6,8 @@ go_library(
name = "control",
srcs = [
"control.go",
+ "fs.go",
+ "lifecycle.go",
"logging.go",
"pprof.go",
"proc.go",
@@ -16,6 +18,7 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/fd",
"//pkg/log",
"//pkg/sentry/fdimport",
@@ -35,6 +38,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/control/fs.go b/pkg/sentry/control/fs.go
new file mode 100644
index 000000000..d19b21f2d
--- /dev/null
+++ b/pkg/sentry/control/fs.go
@@ -0,0 +1,93 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "fmt"
+ "io"
+ "os"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// CatOpts contains options for the Cat RPC call.
+type CatOpts struct {
+ // Files are the filesystem paths for the files to cat.
+ Files []string `json:"files"`
+
+ // FilePayload contains the destination for output.
+ urpc.FilePayload
+}
+
+// Fs includes fs-related functions.
+type Fs struct {
+ Kernel *kernel.Kernel
+}
+
+// Cat is a RPC stub which prints out and returns the content of the files.
+func (f *Fs) Cat(o *CatOpts, _ *struct{}) error {
+ // Create an output stream.
+ if len(o.FilePayload.Files) != 1 {
+ return ErrInvalidFiles
+ }
+
+ output := o.FilePayload.Files[0]
+ for _, file := range o.Files {
+ if err := cat(f.Kernel, file, output); err != nil {
+ return fmt.Errorf("cannot read from file %s: %v", file, err)
+ }
+ }
+
+ return nil
+}
+
+// fileReader encapsulates a fs.File and provides an io.Reader interface.
+type fileReader struct {
+ ctx context.Context
+ file *fs.File
+}
+
+// Read implements io.Reader.Read.
+func (f *fileReader) Read(p []byte) (int, error) {
+ n, err := f.file.Readv(f.ctx, usermem.BytesIOSequence(p))
+ return int(n), err
+}
+
+func cat(k *kernel.Kernel, path string, output *os.File) error {
+ ctx := k.SupervisorContext()
+ mns := k.GlobalInit().Leader().MountNamespace()
+ root := mns.Root()
+ defer root.DecRef(ctx)
+
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, err := mns.FindInode(ctx, root, nil, path, &remainingTraversals)
+ if err != nil {
+ return fmt.Errorf("cannot find file %s: %v", path, err)
+ }
+ defer d.DecRef(ctx)
+
+ file, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ return fmt.Errorf("cannot get file for path %s: %v", path, err)
+ }
+ defer file.DecRef(ctx)
+
+ _, err = io.Copy(output, &fileReader{ctx: ctx, file: file})
+ return err
+}
diff --git a/pkg/sentry/control/lifecycle.go b/pkg/sentry/control/lifecycle.go
new file mode 100644
index 000000000..67abf497d
--- /dev/null
+++ b/pkg/sentry/control/lifecycle.go
@@ -0,0 +1,36 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Lifecycle provides functions related to starting and stopping tasks.
+type Lifecycle struct {
+ Kernel *kernel.Kernel
+}
+
+// Pause pauses all tasks, blocking until they are stopped.
+func (l *Lifecycle) Pause(_, _ *struct{}) error {
+ l.Kernel.Pause()
+ return nil
+}
+
+// Resume resumes all tasks.
+func (l *Lifecycle) Resume(_, _ *struct{}) error {
+ l.Kernel.Unpause()
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
index 24e28a51f..22c8b7fda 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -383,11 +383,6 @@ func (d *dir) DecRef(ctx context.Context) {
d.dirRefs.DecRef(func() { d.Destroy(ctx) })
}
-// StatFS implements kernfs.Inode.StatFS.
-func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) {
- return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
-}
-
// controllerFile represents a generic control file that appears within a cgroup
// directory.
//
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index ec8d58cc9..25d2e39d6 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1161,6 +1161,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
if !d.isSynthetic() {
if stat.Mask != 0 {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // d.dataMu must be held around the update to both the remote
+ // file's size and d.size to serialize with writeback (which
+ // might otherwise write data back up to the old d.size after
+ // the remote file has been truncated).
+ d.dataMu.Lock()
+ }
if err := d.file.setAttr(ctx, p9.SetAttrMask{
Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
@@ -1180,13 +1187,16 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
MTimeSeconds: uint64(stat.Mtime.Sec),
MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
}); err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
return err
}
if stat.Mask&linux.STATX_SIZE != 0 {
// d.size should be kept up to date, and privatized
// copy-on-write mappings of truncated pages need to be
// invalidated, even if InteropModeShared is in effect.
- d.updateSizeLocked(stat.Size)
+ d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
}
}
if d.fs.opts.interop == InteropModeShared {
@@ -1249,6 +1259,14 @@ func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate
// Preconditions: d.metadataMu must be locked.
func (d *dentry) updateSizeLocked(newSize uint64) {
d.dataMu.Lock()
+ d.updateSizeAndUnlockDataMuLocked(newSize)
+}
+
+// Preconditions: d.metadataMu and d.dataMu must be locked.
+//
+// Postconditions: d.dataMu is unlocked.
+// +checklocksrelease:d.dataMu
+func (d *dentry) updateSizeAndUnlockDataMuLocked(newSize uint64) {
oldSize := d.size
atomic.StoreUint64(&d.size, newSize)
// d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
@@ -1257,9 +1275,9 @@ func (d *dentry) updateSizeLocked(newSize uint64) {
// contents beyond the new d.size. (We are still holding d.metadataMu,
// so we can't race with Write or another truncate.)
d.dataMu.Unlock()
- if d.size < oldSize {
+ if newSize < oldSize {
oldpgend, _ := hostarch.PageRoundUp(oldSize)
- newpgend, _ := hostarch.PageRoundUp(d.size)
+ newpgend, _ := hostarch.PageRoundUp(newSize)
if oldpgend != newpgend {
d.mapsMu.Lock()
d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
@@ -1275,8 +1293,8 @@ func (d *dentry) updateSizeLocked(newSize uint64) {
// truncated pages have been removed from the remote file, they
// should be dropped without being written back.
d.dataMu.Lock()
- d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
- d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
+ d.cache.Truncate(newSize, d.fs.mfp.MemoryFile())
+ d.dirty.KeepClean(memmap.MappableRange{newSize, oldpgend})
d.dataMu.Unlock()
}
}
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index cbbc0935a..f54811edf 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -78,7 +78,7 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns
"smaps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &smapsData{task: task}),
"stat": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
"statm": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statmData{task: task}),
- "status": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "status": fs.newStatusInode(ctx, task, pidns, fs.NextIno(), 0444),
"uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 5bb6bc372..0ce3ed797 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -661,34 +661,119 @@ func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
+// statusInode implements kernfs.Inode for /proc/[pid]/status.
//
// +stateify savable
-type statusData struct {
- kernfs.DynamicBytesFile
+type statusInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoStatFS
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
task *kernel.Task
pidns *kernel.PIDNamespace
+ locks vfs.FileLocks
}
-var _ dynamicInode = (*statusData)(nil)
+// statusFD implements vfs.FileDescriptionImpl and vfs.DynamicByteSource for
+// /proc/[pid]/status.
+//
+// +stateify savable
+type statusFD struct {
+ statusFDLowerBase
+ vfs.DynamicBytesFileDescriptionImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+
+ inode *statusInode
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+ userns *auth.UserNamespace // equivalent to struct file::f_cred::user_ns
+}
+
+// statusFDLowerBase is a dumb hack to ensure that statusFD prefers
+// vfs.DynamicBytesFileDescriptionImpl methods to vfs.FileDescriptinDefaultImpl
+// methods.
+//
+// +stateify savable
+type statusFDLowerBase struct {
+ vfs.FileDescriptionDefaultImpl
+}
+
+func (fs *filesystem) newStatusInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, ino uint64, perm linux.FileMode) kernfs.Inode {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode := &statusInode{
+ task: task,
+ pidns: pidns,
+ }
+ inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeRegular|perm)
+ return &taskOwnedInode{Inode: inode, owner: task}
+}
+
+// Open implements kernfs.Inode.Open.
+func (s *statusInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &statusFD{
+ inode: s,
+ task: s.task,
+ pidns: s.pidns,
+ userns: rp.Credentials().UserNamespace,
+ }
+ fd.LockFD.Init(&s.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ fd.SetDataSource(fd)
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (*statusInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return linuxerr.EPERM
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *statusFD) Release(ctx context.Context) {
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (s *statusFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := s.vfsfd.VirtualDentry().Mount().Filesystem()
+ return s.inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (s *statusFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ return linuxerr.EPERM
+}
// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name())
fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus())
fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup()))
fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task))
+
ppid := kernel.ThreadID(0)
if parent := s.task.Parent(); parent != nil {
ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
}
fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
+
tpid := kernel.ThreadID(0)
if tracer := s.task.Tracer(); tracer != nil {
tpid = s.pidns.IDOfTask(tracer)
}
fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
+
+ creds := s.task.Credentials()
+ ruid := creds.RealKUID.In(s.userns).OrOverflow()
+ euid := creds.EffectiveKUID.In(s.userns).OrOverflow()
+ suid := creds.SavedKUID.In(s.userns).OrOverflow()
+ rgid := creds.RealKGID.In(s.userns).OrOverflow()
+ egid := creds.EffectiveKGID.In(s.userns).OrOverflow()
+ sgid := creds.SavedKGID.In(s.userns).OrOverflow()
var fds int
var vss, rss, data uint64
s.task.WithMuLocked(func(t *kernel.Task) {
@@ -701,12 +786,26 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
data = mm.VirtualDataSize()
}
})
+ // Filesystem user/group IDs aren't implemented; effective UID/GID are used
+ // instead.
+ fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid)
+ fmt.Fprintf(buf, "Gid:\t%d\t%d\t%d\t%d\n", rgid, egid, sgid, egid)
fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
+ buf.WriteString("Groups:\t ")
+ // There is a space between each pair of supplemental GIDs, as well as an
+ // unconditional trailing space that some applications actually depend on.
+ var sep string
+ for _, kgid := range creds.ExtraKGIDs {
+ fmt.Fprintf(buf, "%s%d", sep, kgid.In(s.userns).OrOverflow())
+ sep = " "
+ }
+ buf.WriteString(" \n")
+
fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
+
fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count())
- creds := s.task.Credentials()
fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 80dda1559..b121fc1b4 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -27,6 +27,9 @@ type Stack interface {
// integers.
Interfaces() map[int32]Interface
+ // RemoveInterface removes the specified network interface.
+ RemoveInterface(idx int32) error
+
// InterfaceAddrs returns all network interface addresses as a mapping from
// interface indexes to a slice of associated interface address properties.
InterfaceAddrs() map[int32][]InterfaceAddr
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 218d9dafc..621f47e1f 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -45,23 +45,29 @@ func NewTestStack() *TestStack {
}
}
-// Interfaces implements Stack.Interfaces.
+// Interfaces implements Stack.
func (s *TestStack) Interfaces() map[int32]Interface {
return s.InterfacesMap
}
-// InterfaceAddrs implements Stack.InterfaceAddrs.
+// RemoveInterface implements Stack.
+func (s *TestStack) RemoveInterface(idx int32) error {
+ delete(s.InterfacesMap, idx)
+ return nil
+}
+
+// InterfaceAddrs implements Stack.
func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr {
return s.InterfaceAddrsMap
}
-// AddInterfaceAddr implements Stack.AddInterfaceAddr.
+// AddInterfaceAddr implements Stack.
func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr)
return nil
}
-// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr.
+// RemoveInterfaceAddr implements Stack.
func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error {
interfaceAddrs, ok := s.InterfaceAddrsMap[idx]
if !ok {
@@ -79,94 +85,94 @@ func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error {
return nil
}
-// SupportsIPv6 implements Stack.SupportsIPv6.
+// SupportsIPv6 implements Stack.
func (s *TestStack) SupportsIPv6() bool {
return s.SupportsIPv6Flag
}
-// TCPReceiveBufferSize implements Stack.TCPReceiveBufferSize.
+// TCPReceiveBufferSize implements Stack.
func (s *TestStack) TCPReceiveBufferSize() (TCPBufferSize, error) {
return s.TCPRecvBufSize, nil
}
-// SetTCPReceiveBufferSize implements Stack.SetTCPReceiveBufferSize.
+// SetTCPReceiveBufferSize implements Stack.
func (s *TestStack) SetTCPReceiveBufferSize(size TCPBufferSize) error {
s.TCPRecvBufSize = size
return nil
}
-// TCPSendBufferSize implements Stack.TCPSendBufferSize.
+// TCPSendBufferSize implements Stack.
func (s *TestStack) TCPSendBufferSize() (TCPBufferSize, error) {
return s.TCPSendBufSize, nil
}
-// SetTCPSendBufferSize implements Stack.SetTCPSendBufferSize.
+// SetTCPSendBufferSize implements Stack.
func (s *TestStack) SetTCPSendBufferSize(size TCPBufferSize) error {
s.TCPSendBufSize = size
return nil
}
-// TCPSACKEnabled implements Stack.TCPSACKEnabled.
+// TCPSACKEnabled implements Stack.
func (s *TestStack) TCPSACKEnabled() (bool, error) {
return s.TCPSACKFlag, nil
}
-// SetTCPSACKEnabled implements Stack.SetTCPSACKEnabled.
+// SetTCPSACKEnabled implements Stack.
func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
s.TCPSACKFlag = enabled
return nil
}
-// TCPRecovery implements Stack.TCPRecovery.
+// TCPRecovery implements Stack.
func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) {
return s.Recovery, nil
}
-// SetTCPRecovery implements Stack.SetTCPRecovery.
+// SetTCPRecovery implements Stack.
func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error {
s.Recovery = recovery
return nil
}
-// Statistics implements inet.Stack.Statistics.
+// Statistics implements Stack.
func (s *TestStack) Statistics(stat interface{}, arg string) error {
return nil
}
-// RouteTable implements Stack.RouteTable.
+// RouteTable implements Stack.
func (s *TestStack) RouteTable() []Route {
return s.RouteList
}
-// Resume implements Stack.Resume.
+// Resume implements Stack.
func (s *TestStack) Resume() {}
-// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+// RegisteredEndpoints implements Stack.
func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint {
return nil
}
-// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+// CleanupEndpoints implements Stack.
func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
return nil
}
-// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+// RestoreCleanupEndpoints implements Stack.
func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
-// SetForwarding implements inet.Stack.SetForwarding.
+// SetForwarding implements Stack.
func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
s.IPForwarding = enable
return nil
}
-// PortRange implements inet.Stack.PortRange.
+// PortRange implements Stack.
func (*TestStack) PortRange() (uint16, uint16) {
// Use the default Linux values per net/ipv4/af_inet.c:inet_init_net().
return 32768, 28232
}
-// SetPortRange implements inet.Stack.SetPortRange.
+// SetPortRange implements Stack.
func (*TestStack) SetPortRange(start uint16, end uint16) error {
// No-op.
return nil
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index c613f4932..e4e0dc04f 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -220,6 +220,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/abi/linux/errno",
"//pkg/amutex",
+ "//pkg/bitmap",
"//pkg/bits",
"//pkg/bpf",
"//pkg/cleanup",
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 8786a70b5..eff556a0c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -18,10 +18,10 @@ import (
"fmt"
"math"
"strings"
- "sync/atomic"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bitmap"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -84,13 +84,8 @@ type FDTable struct {
// mu protects below.
mu sync.Mutex `state:"nosave"`
- // next is start position to find fd.
- next int32
-
- // used contains the number of non-nil entries. It must be accessed
- // atomically. It may be read atomically without holding mu (but not
- // written).
- used int32
+ // fdBitmap shows which fds are already in use.
+ fdBitmap bitmap.Bitmap `state:"nosave"`
// descriptorTable holds descriptors.
descriptorTable `state:".(map[int32]descriptor)"`
@@ -98,6 +93,8 @@ type FDTable struct {
func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
m := make(map[int32]descriptor)
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
m[fd] = descriptor{
file: file,
@@ -111,12 +108,16 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
ctx := context.Background()
f.initNoLeakCheck() // Initialize table.
- f.used = 0
+ f.fdBitmap = bitmap.New(uint32(math.MaxUint16))
for fd, d := range m {
+ if fd < 0 {
+ panic(fmt.Sprintf("FD is not supposed to be negative. FD: %d", fd))
+ }
+
if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
panic("VFS1 or VFS2 files set")
}
-
+ f.fdBitmap.Add(uint32(fd))
// Note that we do _not_ need to acquire a extra table reference here. The
// table reference will already be accounted for in the file, so we drop the
// reference taken by set above.
@@ -189,8 +190,10 @@ func (f *FDTable) DecRef(ctx context.Context) {
func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
// retries tracks the number of failed TryIncRef attempts for the same FD.
retries := 0
- fd := int32(0)
- for {
+ fds := f.fdBitmap.ToSlice()
+ // Iterate through the fdBitmap.
+ for _, ufd := range fds {
+ fd := int32(ufd)
file, fileVFS2, flags, ok := f.getAll(fd)
if !ok {
break
@@ -218,7 +221,6 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File,
fileVFS2.DecRef(ctx)
}
retries = 0
- fd++
}
}
@@ -226,6 +228,8 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File,
func (f *FDTable) String() string {
var buf strings.Builder
ctx := context.Background()
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
switch {
case file != nil:
@@ -250,10 +254,10 @@ func (f *FDTable) String() string {
}
// NewFDs allocates new FDs guaranteed to be the lowest number available
-// greater than or equal to the fd parameter. All files will share the set
+// greater than or equal to the minFD parameter. All files will share the set
// flags. Success is guaranteed to be all or none.
-func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags FDFlags) (fds []int32, err error) {
- if fd < 0 {
+func (f *FDTable) NewFDs(ctx context.Context, minFD int32, files []*fs.File, flags FDFlags) (fds []int32, err error) {
+ if minFD < 0 {
// Don't accept negative FDs.
return nil, unix.EINVAL
}
@@ -267,31 +271,48 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
if lim.Cur != limits.Infinity {
end = int32(lim.Cur)
}
- if fd >= end {
+ if minFD+int32(len(files)) > end {
return nil, unix.EMFILE
}
}
f.mu.Lock()
- // From f.next to find available fd.
- if fd < f.next {
- fd = f.next
+ // max is used as the largest number in fdBitmap + 1.
+ max := int32(0)
+
+ if !f.fdBitmap.IsEmpty() {
+ max = int32(f.fdBitmap.Maximum())
+ max++
}
+ // Adjust max in case it is less than minFD.
+ if max < minFD {
+ max = minFD
+ }
// Install all entries.
- for i := fd; i < end && len(fds) < len(files); i++ {
- if d, _, _ := f.get(i); d == nil {
- // Set the descriptor.
- f.set(ctx, i, files[len(fds)], flags)
- fds = append(fds, i) // Record the file descriptor.
+ for len(fds) < len(files) {
+ // Try to use free bit in fdBitmap.
+ // If all bits in fdBitmap are used, expand fd to the max.
+ fd := f.fdBitmap.FirstZero(uint32(minFD))
+ if fd == math.MaxInt32 {
+ fd = uint32(max)
+ max++
+ }
+ if fd >= uint32(end) {
+ break
}
+ f.fdBitmap.Add(fd)
+ f.set(ctx, int32(fd), files[len(fds)], flags)
+ fds = append(fds, int32(fd))
+ minFD = int32(fd)
}
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
f.set(ctx, i, nil, FDFlags{})
+ f.fdBitmap.Remove(uint32(i))
}
f.mu.Unlock()
@@ -305,20 +326,15 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return nil, unix.EMFILE
}
- if fd == f.next {
- // Update next search start position.
- f.next = fds[len(fds)-1] + 1
- }
-
f.mu.Unlock()
return fds, nil
}
// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available
-// greater than or equal to the fd parameter. All files will share the set
+// greater than or equal to the minFD parameter. All files will share the set
// flags. Success is guaranteed to be all or none.
-func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
- if fd < 0 {
+func (f *FDTable) NewFDsVFS2(ctx context.Context, minFD int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
+ if minFD < 0 {
// Don't accept negative FDs.
return nil, unix.EINVAL
}
@@ -332,31 +348,47 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
if lim.Cur != limits.Infinity {
end = int32(lim.Cur)
}
- if fd >= end {
+ if minFD >= end {
return nil, unix.EMFILE
}
}
f.mu.Lock()
- // From f.next to find available fd.
- if fd < f.next {
- fd = f.next
+ // max is used as the largest number in fdBitmap + 1.
+ max := int32(0)
+
+ if !f.fdBitmap.IsEmpty() {
+ max = int32(f.fdBitmap.Maximum())
+ max++
}
- // Install all entries.
- for i := fd; i < end && len(fds) < len(files); i++ {
- if d, _, _ := f.getVFS2(i); d == nil {
- // Set the descriptor.
- f.setVFS2(ctx, i, files[len(fds)], flags)
- fds = append(fds, i) // Record the file descriptor.
- }
+ // Adjust max in case it is less than minFD.
+ if max < minFD {
+ max = minFD
}
+ for len(fds) < len(files) {
+ // Try to use free bit in fdBitmap.
+ // If all bits in fdBitmap are used, expand fd to the max.
+ fd := f.fdBitmap.FirstZero(uint32(minFD))
+ if fd == math.MaxInt32 {
+ fd = uint32(max)
+ max++
+ }
+ if fd >= uint32(end) {
+ break
+ }
+ f.fdBitmap.Add(fd)
+ f.setVFS2(ctx, int32(fd), files[len(fds)], flags)
+ fds = append(fds, int32(fd))
+ minFD = int32(fd)
+ }
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
f.setVFS2(ctx, i, nil, FDFlags{})
+ f.fdBitmap.Remove(uint32(i))
}
f.mu.Unlock()
@@ -370,57 +402,19 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
return nil, unix.EMFILE
}
- if fd == f.next {
- // Update next search start position.
- f.next = fds[len(fds)-1] + 1
- }
-
f.mu.Unlock()
return fds, nil
}
-// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// NewFDVFS2 allocates a file descriptor greater than or equal to minFD for
// the given file description. If it succeeds, it takes a reference on file.
-func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
- if minfd < 0 {
- // Don't accept negative FDs.
- return -1, unix.EINVAL
- }
-
- // Default limit.
- end := int32(math.MaxInt32)
-
- // Ensure we don't get past the provided limit.
- if limitSet := limits.FromContext(ctx); limitSet != nil {
- lim := limitSet.Get(limits.NumberOfFiles)
- if lim.Cur != limits.Infinity {
- end = int32(lim.Cur)
- }
- if minfd >= end {
- return -1, unix.EMFILE
- }
- }
-
- f.mu.Lock()
- defer f.mu.Unlock()
-
- // From f.next to find available fd.
- fd := minfd
- if fd < f.next {
- fd = f.next
- }
- for fd < end {
- if d, _, _ := f.getVFS2(fd); d == nil {
- f.setVFS2(ctx, fd, file, flags)
- if fd == f.next {
- // Update next search start position.
- f.next = fd + 1
- }
- return fd, nil
- }
- fd++
+func (f *FDTable) NewFDVFS2(ctx context.Context, minFD int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ files := []*vfs.FileDescription{file}
+ fileSlice, error := f.NewFDsVFS2(ctx, minFD, files, flags)
+ if error != nil {
+ return -1, error
}
- return -1, unix.EMFILE
+ return fileSlice[0], nil
}
// NewFDAt sets the file reference for the given FD. If there is an active
@@ -469,6 +463,11 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
defer f.mu.Unlock()
df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags)
+ // Add fd to fdBitmap.
+ if file != nil || fileVFS2 != nil {
+ f.fdBitmap.Add(uint32(fd))
+ }
+
return df, dfVFS2, nil
}
@@ -573,7 +572,9 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
// Precondition: The caller must be running on the task goroutine, or Task.mu
// must be locked.
func (f *FDTable) GetFDs(ctx context.Context) []int32 {
- fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ fds := make([]int32, 0, int(f.fdBitmap.GetNumOnes()))
f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
fds = append(fds, fd)
})
@@ -583,13 +584,15 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 {
// Fork returns an independent FDTable.
func (f *FDTable) Fork(ctx context.Context) *FDTable {
clone := f.k.NewFDTable()
-
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
// The set function here will acquire an appropriate table
// reference for the clone. We don't need anything else.
if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil {
panic("VFS1 or VFS2 files set")
}
+ clone.fdBitmap.Add(uint32(fd))
})
return clone
}
@@ -604,11 +607,6 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
f.mu.Lock()
- // Update current available position.
- if fd < f.next {
- f.next = fd
- }
-
orig, orig2, _, _ := f.getAll(fd)
// Add reference for caller.
@@ -621,6 +619,7 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
if orig != nil || orig2 != nil {
orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry.
+ f.fdBitmap.Remove(uint32(fd))
}
f.mu.Unlock()
@@ -644,16 +643,13 @@ func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDes
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
if cond(file, fileVFS2, flags) {
df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table.
+ f.fdBitmap.Remove(uint32(fd))
if df != nil {
files = append(files, df)
}
if dfVFS2 != nil {
filesVFS2 = append(filesVFS2, dfVFS2)
}
- // Update current available position.
- if fd < f.next {
- f.next = fd
- }
}
})
f.mu.Unlock()
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index f17f9c59c..2b3e6ef71 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -15,9 +15,11 @@
package kernel
import (
+ "math"
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/bitmap"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -44,6 +46,7 @@ func (f *FDTable) initNoLeakCheck() {
func (f *FDTable) init() {
f.initNoLeakCheck()
f.InitRefs()
+ f.fdBitmap = bitmap.New(uint32(math.MaxUint16))
}
// get gets a file entry.
@@ -162,14 +165,6 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
}
- // Adjust used.
- switch {
- case orig == nil && desc != nil:
- atomic.AddInt32(&f.used, 1)
- case orig != nil && desc == nil:
- atomic.AddInt32(&f.used, -1)
- }
-
if orig != nil {
switch {
case orig.file != nil:
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
index 3ce926950..c111297d7 100644
--- a/pkg/sentry/kernel/msgqueue/msgqueue.go
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -119,14 +119,21 @@ type Queue struct {
type Message struct {
msgEntry
- // mType is an integer representing the type of the sent message.
- mType int64
+ // Type is an integer representing the type of the sent message.
+ Type int64
- // mText is an untyped block of memory.
- mText []byte
+ // Text is an untyped block of memory.
+ Text []byte
- // mSize is the size of mText.
- mSize uint64
+ // Size is the size of Text.
+ Size uint64
+}
+
+// Blocker is used for blocking Queue.Send, and Queue.Receive calls that serves
+// as an abstracted version of kernel.Task. kernel.Task is not directly used to
+// prevent circular dependencies.
+type Blocker interface {
+ Block(C <-chan struct{}) error
}
// FindOrCreate creates a new message queue or returns an existing one. See
@@ -186,6 +193,265 @@ func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error {
return nil
}
+// FindByID returns the queue with the specified ID and an error if the ID
+// doesn't exist.
+func (r *Registry) FindByID(id ipc.ID) (*Queue, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ mech := r.reg.FindByID(id)
+ if mech == nil {
+ return nil, linuxerr.EINVAL
+ }
+ return mech.(*Queue), nil
+}
+
+// Send appends a message to the message queue, and returns an error if sending
+// fails. See msgsnd(2).
+func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) {
+ // Try to perform a non-blocking send using queue.append. If EWOULDBLOCK
+ // is returned, start the blocking procedure. Otherwise, return normally.
+ creds := auth.CredentialsFromContext(ctx)
+ if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+ return err
+ }
+
+ if !wait {
+ return linuxerr.EAGAIN
+ }
+
+ e, ch := waiter.NewChannelEntry(nil)
+ q.senders.EventRegister(&e, waiter.EventOut)
+
+ for {
+ if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+ break
+ }
+ b.Block(ch)
+ }
+
+ q.senders.EventUnregister(&e)
+ return err
+}
+
+// append appends a message to the queue's message list and notifies waiting
+// receivers that a message has been inserted. It returns an error if adding
+// the message would cause the queue to exceed its maximum capacity, which can
+// be used as a signal to block the task. Other errors should be returned as is.
+func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error {
+ if m.Type <= 0 {
+ return linuxerr.EINVAL
+ }
+
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.obj.CheckPermissions(creds, fs.PermMask{Write: true}) {
+ // The calling process does not have write permission on the message
+ // queue, and does not have the CAP_IPC_OWNER capability in the user
+ // namespace that governs its IPC namespace.
+ return linuxerr.EACCES
+ }
+
+ // Queue was removed while the process was waiting.
+ if q.dead {
+ return linuxerr.EIDRM
+ }
+
+ // Check if sufficient space is available (the queue isn't full.) From
+ // the man pages:
+ //
+ // "A message queue is considered to be full if either of the following
+ // conditions is true:
+ //
+ // • Adding a new message to the queue would cause the total number
+ // of bytes in the queue to exceed the queue's maximum size (the
+ // msg_qbytes field).
+ //
+ // • Adding another message to the queue would cause the total
+ // number of messages in the queue to exceed the queue's maximum
+ // size (the msg_qbytes field). This check is necessary to
+ // prevent an unlimited number of zero-length messages being
+ // placed on the queue. Although such messages contain no data,
+ // they nevertheless consume (locked) kernel memory."
+ //
+ // The msg_qbytes field in our implementation is q.maxBytes.
+ if m.Size+q.byteCount > q.maxBytes || q.messageCount+1 > q.maxBytes {
+ return linuxerr.EWOULDBLOCK
+ }
+
+ // Copy the message into the queue.
+ q.messages.PushBack(&m)
+
+ q.byteCount += m.Size
+ q.messageCount++
+ q.sendPID = pid
+ q.sendTime = ktime.NowFromContext(ctx)
+
+ // Notify receivers about the new message.
+ q.receivers.Notify(waiter.EventIn)
+
+ return nil
+}
+
+// Receive removes a message from the queue and returns it. See msgrcv(2).
+func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) {
+ if maxSize < 0 || maxSize > maxMessageBytes {
+ return nil, linuxerr.EINVAL
+ }
+ max := uint64(maxSize)
+
+ // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK
+ // is returned, start the blocking procedure. Otherwise, return normally.
+ creds := auth.CredentialsFromContext(ctx)
+ if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
+ return msg, err
+ }
+
+ if !wait {
+ return nil, linuxerr.ENOMSG
+ }
+
+ e, ch := waiter.NewChannelEntry(nil)
+ q.receivers.EventRegister(&e, waiter.EventIn)
+
+ for {
+ if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
+ break
+ }
+ b.Block(ch)
+ }
+ q.receivers.EventUnregister(&e)
+ return msg, err
+}
+
+// pop pops the first message from the queue that matches the given type. It
+// returns an error for all the cases specified in msgrcv(2). If the queue is
+// empty or no message of the specified type is available, a EWOULDBLOCK error
+// is returned, which can then be used as a signal to block the process or fail.
+func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
+ // The calling process does not have read permission on the message
+ // queue, and does not have the CAP_IPC_OWNER capability in the user
+ // namespace that governs its IPC namespace.
+ return nil, linuxerr.EACCES
+ }
+
+ // Queue was removed while the process was waiting.
+ if q.dead {
+ return nil, linuxerr.EIDRM
+ }
+
+ if q.messages.Empty() {
+ return nil, linuxerr.EWOULDBLOCK
+ }
+
+ // Get a message from the queue.
+ switch {
+ case mType == 0:
+ msg = q.messages.Front()
+ case mType > 0:
+ msg = q.msgOfType(mType, except)
+ case mType < 0:
+ msg = q.msgOfTypeLessThan(-1 * mType)
+ }
+
+ // If no message exists, return a blocking singal.
+ if msg == nil {
+ return nil, linuxerr.EWOULDBLOCK
+ }
+
+ // Check message's size is acceptable.
+ if maxSize < msg.Size {
+ if !truncate {
+ return nil, linuxerr.E2BIG
+ }
+ msg.Size = maxSize
+ msg.Text = msg.Text[:maxSize+1]
+ }
+
+ q.messages.Remove(msg)
+
+ q.byteCount -= msg.Size
+ q.messageCount--
+ q.receivePID = pid
+ q.receiveTime = ktime.NowFromContext(ctx)
+
+ // Notify senders about available space.
+ q.senders.Notify(waiter.EventOut)
+
+ return msg, nil
+}
+
+// Copy copies a message from the queue without deleting it. If no message
+// exists, an error is returned. See msgrcv(MSG_COPY).
+func (q *Queue) Copy(mType int64) (*Message, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if mType < 0 || q.messages.Empty() {
+ return nil, linuxerr.ENOMSG
+ }
+
+ msg := q.msgAtIndex(mType)
+ if msg == nil {
+ return nil, linuxerr.ENOMSG
+ }
+ return msg, nil
+}
+
+// msgOfType returns the first message with the specified type, nil if no
+// message is found. If except is true, the first message of a type not equal
+// to mType will be returned.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgOfType(mType int64, except bool) *Message {
+ if except {
+ for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+ if msg.Type != mType {
+ return msg
+ }
+ }
+ return nil
+ }
+
+ for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+ if msg.Type == mType {
+ return msg
+ }
+ }
+ return nil
+}
+
+// msgOfTypeLessThan return the the first message with the lowest type less
+// than or equal to mType, nil if no such message exists.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgOfTypeLessThan(mType int64) (m *Message) {
+ min := mType
+ for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+ if msg.Type <= mType && msg.Type < min {
+ m = msg
+ min = msg.Type
+ }
+ }
+ return m
+}
+
+// msgAtIndex returns a pointer to a message at given index, nil if non exits.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgAtIndex(mType int64) *Message {
+ msg := q.messages.Front()
+ for ; mType != 0 && msg != nil; mType-- {
+ msg = msg.Next()
+ }
+ return msg
+}
+
// Lock implements ipc.Mechanism.Lock.
func (q *Queue) Lock() {
q.mu.Lock()
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index 28a613a54..8fd8287b3 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -101,7 +101,7 @@ func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegi
// Store the physical address in the slot. This is used to
// avoid calls to handleBluepillFault in the future (see
// machine.mapPhysical).
- atomic.StoreUintptr(&m.usedSlots[slot], physical)
+ atomic.StoreUintptr(&m.usedSlots[slot], physicalStart)
// Successfully added region; we can increment nextSlot and
// allow another set to proceed here.
atomic.StoreUint32(&m.nextSlot, slot+1)
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index f63ab6aba..0f0c1e73b 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build go1.12 && !go1.18
-// +build go1.12,!go1.18
+//go:build go1.12
+// +build go1.12
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
package kvm
@@ -28,7 +30,7 @@ import (
)
//go:linkname throw runtime.throw
-func throw(string)
+func throw(s string)
// vCPUPtr returns a CPU for the given address.
//
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 1b5d5f66e..e7092a756 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -70,7 +70,7 @@ type machine struct {
// tscControl checks whether cpu supports TSC scaling
tscControl bool
- // usedSlots is the set of used physical addresses (sorted).
+ // usedSlots is the set of used physical addresses (not sorted).
usedSlots []uintptr
// nextID is the next vCPU ID.
@@ -296,13 +296,20 @@ func newMachine(vm int) (*machine, error) {
return m, nil
}
-// hasSlot returns true iff the given address is mapped.
+// hasSlot returns true if the given address is mapped.
//
// This must be done via a linear scan.
//
//go:nosplit
func (m *machine) hasSlot(physical uintptr) bool {
- for i := 0; i < len(m.usedSlots); i++ {
+ slotLen := int(atomic.LoadUint32(&m.nextSlot))
+ // When slots are being updated, nextSlot is ^uint32(0). As this situation
+ // is less likely happen, we just set the slotLen to m.maxSlots, and scan
+ // the whole usedSlots array.
+ if slotLen == int(^uint32(0)) {
+ slotLen = m.maxSlots
+ }
+ for i := 0; i < slotLen; i++ {
if p := atomic.LoadUintptr(&m.usedSlots[i]); p == physical {
return true
}
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 35660e827..cc3a1253b 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build go1.12 && !go1.18
-// +build go1.12,!go1.18
+//go:build go1.12
+// +build go1.12
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
package kvm
diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
index ffd4665f4..304722200 100644
--- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build go1.12 && !go1.18
-// +build go1.12,!go1.18
+//go:build go1.12
+// +build go1.12
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
package ptrace
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 7a4e78a5f..61111ac6c 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -309,6 +309,11 @@ func (s *Stack) Interfaces() map[int32]inet.Interface {
return interfaces
}
+// RemoveInterface implements inet.Stack.RemoveInterface.
+func (*Stack) RemoveInterface(int32) error {
+ return linuxerr.EACCES
+}
+
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
addrs := make(map[int32][]inet.InterfaceAddr)
@@ -319,12 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
}
// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
+func (*Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
return linuxerr.EACCES
}
// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
-func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
+func (*Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
return linuxerr.EACCES
}
@@ -339,7 +344,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
}
// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
-func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+func (*Stack) SetTCPReceiveBufferSize(inet.TCPBufferSize) error {
return linuxerr.EACCES
}
@@ -349,7 +354,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
}
// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
-func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+func (*Stack) SetTCPSendBufferSize(inet.TCPBufferSize) error {
return linuxerr.EACCES
}
@@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) {
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(bool) error {
+func (*Stack) SetTCPSACKEnabled(bool) error {
return linuxerr.EACCES
}
@@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
}
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
-func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
+func (*Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
return linuxerr.EACCES
}
@@ -470,19 +475,19 @@ func (s *Stack) RouteTable() []inet.Route {
}
// Resume implements inet.Stack.Resume.
-func (s *Stack) Resume() {}
+func (*Stack) Resume() {}
// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
-func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+func (*Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
-func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+func (*Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
-func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
+func (*Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
// SetForwarding implements inet.Stack.SetForwarding.
-func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
+func (*Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return linuxerr.EACCES
}
@@ -493,6 +498,6 @@ func (*Stack) PortRange() (uint16, uint16) {
}
// SetPortRange implements inet.Stack.SetPortRange.
-func (*Stack) SetPortRange(start uint16, end uint16) error {
+func (*Stack) SetPortRange(uint16, uint16) error {
return linuxerr.EACCES
}
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index 86f6419dc..d526acb73 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -161,6 +161,47 @@ func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlin
return nil
}
+// delLink handles RTM_DELLINK requests.
+func (p *Protocol) delLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifinfomsg linux.InterfaceInfoMessage
+ attrs, ok := msg.GetData(&ifinfomsg)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ if ifinfomsg.Index == 0 {
+ // The index is unspecified, search by the interface name.
+ ahdr, value, _, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ switch ahdr.Type {
+ case linux.IFLA_IFNAME:
+ if len(value) < 1 {
+ return syserr.ErrInvalidArgument
+ }
+ ifname := string(value[:len(value)-1])
+ for idx, ifa := range stack.Interfaces() {
+ if ifname == ifa.Name {
+ ifinfomsg.Index = idx
+ break
+ }
+ }
+ default:
+ return syserr.ErrInvalidArgument
+ }
+ if ifinfomsg.Index == 0 {
+ return syserr.ErrNoDevice
+ }
+ }
+ return syserr.FromError(stack.RemoveInterface(ifinfomsg.Index))
+}
+
// addNewLinkMessage appends RTM_NEWLINK message for the given interface into
// the message set.
func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
@@ -537,6 +578,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
switch hdr.Type {
case linux.RTM_GETLINK:
return p.getLink(ctx, msg, ms)
+ case linux.RTM_DELLINK:
+ return p.delLink(ctx, msg, ms)
case linux.RTM_GETROUTE:
return p.dumpRoutes(ctx, msg, ms)
case linux.RTM_NEWADDR:
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 0fd0ad32c..208ab9909 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -71,6 +71,12 @@ func (s *Stack) Interfaces() map[int32]inet.Interface {
return is
}
+// RemoveInterface implements inet.Stack.RemoveInterface.
+func (s *Stack) RemoveInterface(idx int32) error {
+ nic := tcpip.NICID(idx)
+ return syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)).ToError()
+}
+
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
nicAddrs := make(map[int32][]inet.InterfaceAddr)
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index ccccce6a9..b5a371d9a 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -86,6 +86,7 @@ go_library(
"//pkg/sentry/kernel/eventfd",
"//pkg/sentry/kernel/fasync",
"//pkg/sentry/kernel/ipc",
+ "//pkg/sentry/kernel/msgqueue",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/kernel/shm",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 6f44d767b..1ead3c7e8 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -122,8 +122,8 @@ var AMD64 = &kernel.SyscallTable{
66: syscalls.Supported("semctl", Semctl),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.Supported("msgget", Msgget),
- 69: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 70: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 69: syscalls.Supported("msgsnd", Msgsnd),
+ 70: syscalls.Supported("msgrcv", Msgrcv),
71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
@@ -618,8 +618,8 @@ var ARM64 = &kernel.SyscallTable{
185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
186: syscalls.Supported("msgget", Msgget),
187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
- 188: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 189: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 188: syscalls.Supported("msgrcv", Msgrcv),
+ 189: syscalls.Supported("msgsnd", Msgsnd),
190: syscalls.Supported("semget", Semget),
191: syscalls.Supported("semctl", Semctl),
192: syscalls.Supported("semtimedop", Semtimedop),
diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go
index 3476e218d..5259ade90 100644
--- a/pkg/sentry/syscalls/linux/sys_msgqueue.go
+++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go
@@ -17,10 +17,12 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/msgqueue"
)
// Msgget implements msgget(2).
@@ -41,6 +43,89 @@ func Msgget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return uintptr(queue.ID()), nil, nil
}
+// Msgsnd implements msgsnd(2).
+func Msgsnd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := ipc.ID(args[0].Int())
+ msgAddr := args[1].Pointer()
+ size := args[2].Int64()
+ flag := args[3].Int()
+
+ if size < 0 || size > linux.MSGMAX {
+ return 0, nil, linuxerr.EINVAL
+ }
+
+ wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT
+ pid := int32(t.ThreadGroup().ID())
+
+ buf := linux.MsgBuf{
+ Text: make([]byte, size),
+ }
+ if _, err := buf.CopyIn(t, msgAddr); err != nil {
+ return 0, nil, err
+ }
+
+ queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ msg := msgqueue.Message{
+ Type: int64(buf.Type),
+ Text: buf.Text,
+ Size: uint64(size),
+ }
+ return 0, nil, queue.Send(t, msg, t, wait, pid)
+}
+
+// Msgrcv implements msgrcv(2).
+func Msgrcv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := ipc.ID(args[0].Int())
+ msgAddr := args[1].Pointer()
+ size := args[2].Int64()
+ mType := args[3].Int64()
+ flag := args[4].Int()
+
+ wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT
+ except := flag&linux.MSG_EXCEPT == linux.MSG_EXCEPT
+ truncate := flag&linux.MSG_NOERROR == linux.MSG_NOERROR
+
+ msgCopy := flag&linux.MSG_COPY == linux.MSG_COPY
+
+ msg, err := receive(t, id, mType, size, msgCopy, wait, truncate, except)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ buf := linux.MsgBuf{
+ Type: primitive.Int64(msg.Type),
+ Text: msg.Text,
+ }
+ if _, err := buf.CopyOut(t, msgAddr); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(msg.Size), nil, nil
+}
+
+// receive returns a message from the queue with the given ID. If msgCopy is
+// true, a message is copied from the queue without being removed. Otherwise,
+// a message is removed from the queue and returned.
+func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wait, truncate, except bool) (*msgqueue.Message, error) {
+ pid := int32(t.ThreadGroup().ID())
+
+ queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id)
+ if err != nil {
+ return nil, err
+ }
+
+ if msgCopy {
+ if wait || except {
+ return nil, linuxerr.EINVAL
+ }
+ return queue.Copy(mType)
+ }
+ return queue.Receive(t, t, mType, maxSize, wait, truncate, except, pid)
+}
+
// Msgctl implements msgctl(2).
func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := ipc.ID(args[0].Int())
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index a16b6b4d6..2ef1e6404 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -219,6 +219,21 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return 0, nil, t.DropBoundingCapability(cp)
+ case linux.PR_SET_CHILD_SUBREAPER:
+ // "If arg2 is nonzero, set the "child subreaper" attribute of
+ // the calling process; if arg2 is zero, unset the attribute."
+ //
+ // TODO(gvisor.dev/issues/2323): We only support setting, and
+ // only if the task is already TID 1 in the PID namespace,
+ // because it already acts as a subreaper in that case.
+ isPid1 := t.PIDNamespace().IDOfTask(t) == kernel.InitTID
+ if args[1].Int() != 0 && isPid1 {
+ return 0, nil, nil
+ }
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, linuxerr.EINVAL
+
case linux.PR_GET_TIMING,
linux.PR_SET_TIMING,
linux.PR_GET_TSC,
@@ -230,7 +245,6 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
linux.PR_MCE_KILL,
linux.PR_MCE_KILL_GET,
linux.PR_GET_TID_ADDRESS,
- linux.PR_SET_CHILD_SUBREAPER,
linux.PR_GET_CHILD_SUBREAPER,
linux.PR_GET_THP_DISABLE,
linux.PR_SET_THP_DISABLE,
diff --git a/pkg/sync/goyield_unsafe.go b/pkg/sync/goyield_unsafe.go
index 8639bb64e..757edbaba 100644
--- a/pkg/sync/goyield_unsafe.go
+++ b/pkg/sync/goyield_unsafe.go
@@ -3,10 +3,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build go1.14 && !go1.18
-// +build go1.14,!go1.18
+//go:build go1.14
+// +build go1.14
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
package sync
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
index 1d9cf304e..49d4109a9 100644
--- a/pkg/sync/runtime_unsafe.go
+++ b/pkg/sync/runtime_unsafe.go
@@ -6,8 +6,13 @@
//go:build go1.13 && !go1.18
// +build go1.13,!go1.18
-// Check go:linkname function signatures, type definitions, and constants when
-// updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
+
+// Check type definitions and constants when updating Go version.
+//
+// TODO(b/165820485): add these checks to checklinkname.
package sync
@@ -109,10 +114,10 @@ type maptype struct {
// These functions are only used within the sync package.
//go:linkname semacquire sync.runtime_Semacquire
-func semacquire(s *uint32)
+func semacquire(addr *uint32)
//go:linkname semrelease sync.runtime_Semrelease
-func semrelease(s *uint32, handoff bool, skipframes int)
+func semrelease(addr *uint32, handoff bool, skipframes int)
//go:linkname canSpin sync.runtime_canSpin
func canSpin(i int) bool
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index e8e716db0..48356c343 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -56,6 +56,7 @@ import (
// linkDispatcher reads packets from the link FD and dispatches them to the
// NetworkDispatcher.
type linkDispatcher interface {
+ stop()
dispatch() (bool, tcpip.Error)
}
@@ -381,16 +382,27 @@ func isSocketFD(fd int) (bool, error) {
// Attach launches the goroutine that reads packets from the file descriptor and
// dispatches them via the provided dispatcher.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
- // Link endpoints are not savable. When transportation endpoints are
- // saved, they stop sending outgoing packets and all incoming packets
- // are rejected.
- for i := range e.inboundDispatchers {
- e.wg.Add(1)
- go func(i int) { // S/R-SAFE: See above.
- e.dispatchLoop(e.inboundDispatchers[i])
- e.wg.Done()
- }(i)
+ // nil means the NIC is being removed.
+ if dispatcher == nil && e.dispatcher != nil {
+ for _, dispatcher := range e.inboundDispatchers {
+ dispatcher.stop()
+ }
+ e.Wait()
+ e.dispatcher = nil
+ return
+ }
+ if dispatcher != nil && e.dispatcher == nil {
+ e.dispatcher = dispatcher
+ // Link endpoints are not savable. When transportation endpoints are
+ // saved, they stop sending outgoing packets and all incoming packets
+ // are rejected.
+ for i := range e.inboundDispatchers {
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
+ }
}
}
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index bfae34ab9..3f516cab5 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -114,6 +114,7 @@ func (t tPacketHdr) Payload() []byte {
// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
// See: mmap_amd64_unsafe.go for implementation details.
type packetMMapDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -129,18 +130,18 @@ type packetMMapDispatcher struct {
ringOffset int
}
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, bool, tcpip.Error) {
hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ stopped, errno := rawfile.BlockingPollUntilStopped(d.efd, d.fd, unix.POLLIN|unix.POLLERR)
+ if errno != 0 {
if errno == unix.EINTR {
continue
}
- return nil, rawfile.TranslateErrno(errno)
+ return nil, stopped, rawfile.TranslateErrno(errno)
+ }
+ if stopped {
+ return nil, true, nil
}
if hdr.tpStatus()&tpStatusCopy != 0 {
// This frame is truncated so skip it after flipping the
@@ -158,14 +159,14 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
// Release packet to kernel.
hdr.setTPStatus(tpStatusKernel)
d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
+ return pkt, false, nil
}
// dispatch reads packets from an mmaped ring buffer and dispatches them to the
// network stack.
func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
+ pkt, stopped, err := d.readMMappedPacket()
+ if err != nil || stopped {
return false, err
}
var (
diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 58d5dfeef..5b786169a 100644
--- a/pkg/tcpip/link/fdbased/mmap_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -47,9 +47,14 @@ func (t tPacketHdr) setTPStatus(status uint32) {
}
func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
d := &packetMMapDispatcher{
- fd: fd,
- e: e,
+ stopFd: stopFd,
+ fd: fd,
+ e: e,
}
pageSize := unix.Getpagesize()
if tpBlockSize%pageSize != 0 {
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index ab2855a63..fab34c5fa 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -18,6 +18,8 @@
package fdbased
import (
+ "fmt"
+
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -114,9 +116,36 @@ func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView {
return buffer.NewVectorisedView(n, views)
}
+// stopFd is an eventfd used to signal the stop of a dispatcher.
+type stopFd struct {
+ efd int
+}
+
+func newStopFd() (stopFd, error) {
+ efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
+ if err != nil {
+ return stopFd{efd: -1}, fmt.Errorf("failed to create eventfd: %w", err)
+ }
+ return stopFd{efd: efd}, nil
+}
+
+// stop writes to the eventfd and notifies the dispatcher to stop. It does not
+// block.
+func (s *stopFd) stop() {
+ increment := []byte{1, 0, 0, 0, 0, 0, 0, 0}
+ if n, err := unix.Write(s.efd, increment); n != len(increment) || err != nil {
+ // There are two possible errors documented in eventfd(2) for writing:
+ // 1. We are writing 8 bytes and not 0xffffffffffffff, thus no EINVAL.
+ // 2. stop is only supposed to be called once, it can't reach the limit,
+ // thus no EAGAIN.
+ panic(fmt.Sprintf("write(efd) = (%d, %s), want (%d, nil)", n, err, len(increment)))
+ }
+}
+
// readVDispatcher uses readv() system call to read inbound packets and
// dispatches them.
type readVDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -128,7 +157,15 @@ type readVDispatcher struct {
}
func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- d := &readVDispatcher{fd: fd, e: e}
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
+ d := &readVDispatcher{
+ stopFd: stopFd,
+ fd: fd,
+ e: e,
+ }
skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
return d, nil
@@ -136,8 +173,8 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
// dispatch reads one packet from the file descriptor and dispatches it.
func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
- n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs())
- if n == 0 || err != nil {
+ n, err := rawfile.BlockingReadvUntilStopped(d.efd, d.fd, d.buf.nextIovecs())
+ if n <= 0 || err != nil {
return false, err
}
@@ -184,6 +221,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
// dispatches them.
type recvMMsgDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -207,7 +245,12 @@ const (
)
func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
d := &recvMMsgDispatcher{
+ stopFd: stopFd,
fd: fd,
e: e,
bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
@@ -235,8 +278,8 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
d.msgHdrs[k].Msg.SetIovlen(iovLen)
}
- nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
- if err != nil {
+ nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.efd, d.fd, d.msgHdrs)
+ if nMsgs == -1 || err != nil {
return false, err
}
// Process each of received packets.
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index b1a28491d..40bd5560b 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -115,6 +115,13 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ // nil means the NIC is being removed.
+ if dispatcher == nil {
+ e.lower.Attach(nil)
+ e.Wait()
+ e.dispatcher = nil
+ return
+ }
e.dispatcher = dispatcher
e.lower.Attach(e)
}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index da900c24b..0b7b9e3de 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build ((linux && amd64) || (linux && arm64)) && go1.12 && !go1.18
+//go:build ((linux && amd64) || (linux && arm64)) && go1.12
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.18
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanied vet check or version guard build tag.
package rawfile
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 53448a641..e76fc55b6 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -170,46 +170,63 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
}
}
-// BlockingReadv reads from a file descriptor that is set up as non-blocking and
-// stores the data in a list of iovecs buffers. If no data is available, it will
-// block in a poll() syscall until the file descriptor becomes readable.
-func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
+// BlockingReadvUntilStopped reads from a file descriptor that is set up as
+// non-blocking and stores the data in a list of iovecs buffers. If no data is
+// available, it will block in a poll() syscall until the file descriptor
+// becomes readable or stop is signalled (efd becomes readable). Returns -1 in
+// the latter case.
+func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
return int(n), nil
}
- event := PollEvent{
- FD: int32(fd),
- Events: 1, // POLLIN
+ stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
+ if stopped {
+ return -1, nil
}
-
- _, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
return 0, TranslateErrno(e)
}
}
}
-// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
-// and stores the received messages in a slice of MMsgHdr structures. If no data
-// is available, it will block in a poll() syscall until the file descriptor
-// becomes readable.
-func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
+// BlockingRecvMMsgUntilStopped reads from a file descriptor that is set up as
+// non-blocking and stores the received messages in a slice of MMsgHdr
+// structures. If no data is available, it will block in a poll() syscall until
+// the file descriptor becomes readable or stop is signalled (efd becomes
+// readable). Returns -1 in the latter case.
+func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
for {
n, _, e := unix.RawSyscall6(unix.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
if e == 0 {
return int(n), nil
}
- event := PollEvent{
- FD: int32(fd),
- Events: 1, // POLLIN
+ stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
+ if stopped {
+ return -1, nil
}
-
- if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != unix.EINTR {
+ if e != 0 && e != unix.EINTR {
return 0, TranslateErrno(e)
}
}
}
+
+// BlockingPollUntilStopped polls for events on fd or until a stop is signalled
+// on the event fd efd. Returns true if stopped, i.e., efd has event POLLIN.
+func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) {
+ pevents := [...]PollEvent{
+ {
+ FD: int32(efd),
+ Events: unix.POLLIN,
+ },
+ {
+ FD: int32(fd),
+ Events: events,
+ },
+ }
+ _, errno := BlockingPoll(&pevents[0], len(pevents), nil)
+ return pevents[0].Revents&unix.POLLIN != 0, errno
+}
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
index 3bb864ed2..d3edede63 100644
--- a/pkg/tcpip/link/sniffer/pcap.go
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -14,7 +14,14 @@
package sniffer
-import "time"
+import (
+ "encoding"
+ "encoding/binary"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
type pcapHeader struct {
// MagicNumber is the file magic number.
@@ -39,25 +46,38 @@ type pcapHeader struct {
Network uint32
}
-type pcapPacketHeader struct {
- // Seconds is the timestamp seconds.
- Seconds uint32
-
- // Microseconds is the timestamp microseconds.
- Microseconds uint32
+var _ encoding.BinaryMarshaler = (*pcapPacket)(nil)
- // IncludedLength is the number of octets of packet saved in file.
- IncludedLength uint32
-
- // OriginalLength is the actual length of packet.
- OriginalLength uint32
+type pcapPacket struct {
+ timestamp time.Time
+ packet *stack.PacketBuffer
+ maxCaptureLen int
}
-func newPCAPPacketHeader(now time.Time, incLen, orgLen uint32) pcapPacketHeader {
- return pcapPacketHeader{
- Seconds: uint32(now.Unix()),
- Microseconds: uint32(now.Nanosecond() / 1000),
- IncludedLength: incLen,
- OriginalLength: orgLen,
+func (p *pcapPacket) MarshalBinary() ([]byte, error) {
+ packetSize := p.packet.Size()
+ captureLen := p.maxCaptureLen
+ if packetSize < captureLen {
+ captureLen = packetSize
+ }
+ b := make([]byte, 16+captureLen)
+ binary.BigEndian.PutUint32(b[0:4], uint32(p.timestamp.Unix()))
+ binary.BigEndian.PutUint32(b[4:8], uint32(p.timestamp.Nanosecond()/1000))
+ binary.BigEndian.PutUint32(b[8:12], uint32(captureLen))
+ binary.BigEndian.PutUint32(b[12:16], uint32(packetSize))
+ w := tcpip.SliceWriter(b[16:])
+ for _, v := range p.packet.Views() {
+ if captureLen == 0 {
+ break
+ }
+ if len(v) > captureLen {
+ v = v[:captureLen]
+ }
+ n, err := w.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ captureLen -= n
}
+ return b, nil
}
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 3df826f3c..28a172e71 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -151,33 +151,16 @@ func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumbe
logPacket(e.logPrefix, dir, protocol, pkt)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
- totalLength := pkt.Size()
- length := totalLength
- if max := int(e.maxPCAPLen); length > max {
- length = max
+ packet := pcapPacket{
+ timestamp: time.Now(),
+ packet: pkt,
+ maxCaptureLen: int(e.maxPCAPLen),
}
- packetHeader := newPCAPPacketHeader(time.Now(), uint32(length), uint32(totalLength))
- packet := make([]byte, binary.Size(packetHeader)+length)
- {
- writer := tcpip.SliceWriter(packet)
- if err := binary.Write(&writer, binary.BigEndian, packetHeader); err != nil {
- panic(err)
- }
- for _, b := range pkt.Views() {
- if length == 0 {
- break
- }
- if len(b) > length {
- b = b[:length]
- }
- n, err := writer.Write(b)
- if err != nil {
- panic(err)
- }
- length -= n
- }
+ b, err := packet.MarshalBinary()
+ if err != nil {
+ panic(err)
}
- if _, err := writer.Write(packet); err != nil {
+ if _, err := writer.Write(b); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index c90974693..2257f728e 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -39,7 +39,6 @@ go_test(
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/internal/testutil",
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 4a4448cf9..73407be67 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -32,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
@@ -3339,7 +3338,7 @@ func TestCloseLocking(t *testing.T) {
defer wg.Done()
for i := 0; i < iterations; i++ {
- if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil {
t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
return
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 81fabe29a..c73890c4c 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -780,6 +780,9 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error {
if !ok {
return &tcpip.ErrUnknownNICID{}
}
+ if nic.IsLoopback() {
+ return &tcpip.ErrNotSupported{}
+ }
delete(s.nics, id)
// Remove routes in-place. n tracks the number of routes written.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 21951d05a..3089c0ef4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -719,38 +719,59 @@ func TestRemoveUnknownNIC(t *testing.T) {
}
func TestRemoveNIC(t *testing.T) {
- const nicID = 1
+ for _, tt := range []struct {
+ name string
+ linkep stack.LinkEndpoint
+ expectErr tcpip.Error
+ }{
+ {
+ name: "loopback",
+ linkep: loopback.New(),
+ expectErr: &tcpip.ErrNotSupported{},
+ },
+ {
+ name: "channel",
+ linkep: channel.New(0, defaultMTU, ""),
+ expectErr: nil,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
- e := linkEPWithMockedAttach{
- LinkEndpoint: loopback.New(),
- }
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: tt.linkep,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // NIC should be present in NICInfo and attached to a NetworkDispatcher.
- allNICInfo := s.NICInfo()
- if _, ok := allNICInfo[nicID]; !ok {
- t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
- }
- if !e.isAttached() {
- t.Fatal("link endpoint not attached to a network dispatcher")
- }
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
+ }
- // Removing a NIC should remove it from NICInfo and e should be detached from
- // the NetworkDispatcher.
- if err := s.RemoveNIC(nicID); err != nil {
- t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
- }
- if nicInfo, ok := s.NICInfo()[nicID]; ok {
- t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
- }
- if e.isAttached() {
- t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want {
+ t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want)
+ }
+ if tt.expectErr == nil {
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
+ }
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ }
+ }
+ })
}
}
diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go
index e5b0ec3ae..60b532798 100644
--- a/runsc/boot/controller.go
+++ b/runsc/boot/controller.go
@@ -57,20 +57,12 @@ const (
// ContMgrExecuteAsync executes a command in a container.
ContMgrExecuteAsync = "containerManager.ExecuteAsync"
- // ContMgrPause pauses the sandbox (note that individual containers cannot be
- // paused).
- ContMgrPause = "containerManager.Pause"
-
// ContMgrProcesses lists processes running in a container.
ContMgrProcesses = "containerManager.Processes"
// ContMgrRestore restores a container from a statefile.
ContMgrRestore = "containerManager.Restore"
- // ContMgrResume unpauses the paused sandbox (note that individual containers
- // cannot be resumed).
- ContMgrResume = "containerManager.Resume"
-
// ContMgrSignal sends a signal to a container.
ContMgrSignal = "containerManager.Signal"
@@ -111,6 +103,17 @@ const (
LoggingChange = "Logging.Change"
)
+// Lifecycle related commands (see lifecycle.go for more details).
+const (
+ LifecyclePause = "Lifecycle.Pause"
+ LifecycleResume = "Lifecycle.Resume"
+)
+
+// Filesystem related commands (see fs.go for more details).
+const (
+ FsCat = "Fs.Cat"
+)
+
// ControlSocketAddr generates an abstract unix socket name for the given ID.
func ControlSocketAddr(id string) string {
return fmt.Sprintf("\x00runsc-sandbox.%s", id)
@@ -152,6 +155,8 @@ func newController(fd int, l *Loader) (*controller, error) {
ctrl.srv.Register(&debug{})
ctrl.srv.Register(&control.Logging{})
+ ctrl.srv.Register(&control.Lifecycle{l.k})
+ ctrl.srv.Register(&control.Fs{l.k})
if l.root.conf.ProfileEnable {
ctrl.srv.Register(control.NewProfile(l.k))
@@ -340,17 +345,6 @@ func (cm *containerManager) Checkpoint(o *control.SaveOpts, _ *struct{}) error {
return state.Save(o, nil)
}
-// Pause suspends a sandbox.
-func (cm *containerManager) Pause(_, _ *struct{}) error {
- log.Debugf("containerManager.Pause")
- // TODO(gvisor.dev/issues/6243): save/restore not supported w/ hostinet
- if cm.l.root.conf.Network == config.NetworkHost {
- return errors.New("pause not supported when using hostinet")
- }
- cm.l.k.Pause()
- return nil
-}
-
// RestoreOpts contains options related to restoring a container's file system.
type RestoreOpts struct {
// FilePayload contains the state file to be restored, followed by the
@@ -482,13 +476,6 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error {
return nil
}
-// Resume unpauses a sandbox.
-func (cm *containerManager) Resume(_, _ *struct{}) error {
- log.Debugf("containerManager.Resume")
- cm.l.k.Unpause()
- return nil
-}
-
// Wait waits for the init process in the given container.
func (cm *containerManager) Wait(cid *string, waitStatus *uint32) error {
log.Debugf("containerManager.Wait, cid: %s", *cid)
diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go
index 7b11b3367..1fe9c6435 100644
--- a/runsc/cmd/chroot.go
+++ b/runsc/cmd/chroot.go
@@ -59,6 +59,23 @@ func pivotRoot(root string) error {
return nil
}
+func copyFile(dst, src string) error {
+ in, err := os.Open(src)
+ if err != nil {
+ return err
+ }
+ defer in.Close()
+
+ out, err := os.Create(dst)
+ if err != nil {
+ return err
+ }
+ defer out.Close()
+
+ _, err = out.ReadFrom(in)
+ return err
+}
+
// setUpChroot creates an empty directory with runsc mounted at /runsc and proc
// mounted at /proc.
func setUpChroot(pidns bool) error {
@@ -78,6 +95,14 @@ func setUpChroot(pidns bool) error {
return fmt.Errorf("error mounting tmpfs in choot: %v", err)
}
+ if err := os.Mkdir(filepath.Join(chroot, "etc"), 0755); err != nil {
+ return fmt.Errorf("error creating /etc in chroot: %v", err)
+ }
+
+ if err := copyFile(filepath.Join(chroot, "etc/localtime"), "/etc/localtime"); err != nil {
+ log.Warningf("Failed to copy /etc/localtime: %v. UTC timezone will be used.", err)
+ }
+
if pidns {
flags := uint32(unix.MS_NOSUID | unix.MS_NODEV | unix.MS_NOEXEC | unix.MS_RDONLY)
if err := mountInChroot(chroot, "proc", "/proc", "proc", flags); err != nil {
diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go
index da81cf048..f773ccca0 100644
--- a/runsc/cmd/debug.go
+++ b/runsc/cmd/debug.go
@@ -48,6 +48,7 @@ type Debug struct {
delay time.Duration
duration time.Duration
ps bool
+ cat stringSlice
}
// Name implements subcommands.Command.
@@ -81,6 +82,7 @@ func (d *Debug) SetFlags(f *flag.FlagSet) {
f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).")
f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.")
f.BoolVar(&d.ps, "ps", false, "lists processes")
+ f.Var(&d.cat, "cat", "reads files and print to standard output")
}
// Execute implements subcommands.Command.Execute.
@@ -367,5 +369,11 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{})
return subcommands.ExitFailure
}
+ if d.cat != nil {
+ if err := c.Cat(d.cat, os.Stdout); err != nil {
+ return Errorf("Cat failed: %v", err)
+ }
+ }
+
return subcommands.ExitSuccess
}
diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go
index 20e05f141..2193e9040 100644
--- a/runsc/cmd/gofer.go
+++ b/runsc/cmd/gofer.go
@@ -285,16 +285,22 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error {
// Prepare tree structure for pivot_root(2).
if err := os.Mkdir("/proc/proc", 0755); err != nil {
- Fatalf("%v", err)
+ Fatalf("error creating /proc/proc: %v", err)
}
if err := os.Mkdir("/proc/root", 0755); err != nil {
- Fatalf("%v", err)
+ Fatalf("error creating /proc/root: %v", err)
+ }
+ if err := os.Mkdir("/proc/etc", 0755); err != nil {
+ Fatalf("error creating /proc/etc: %v", err)
}
// This cannot use SafeMount because there's no available procfs. But we
// know that /proc is an empty tmpfs mount, so this is safe.
if err := unix.Mount("runsc-proc", "/proc/proc", "proc", flags|unix.MS_RDONLY, ""); err != nil {
Fatalf("error mounting proc: %v", err)
}
+ if err := copyFile("/proc/etc/localtime", "/etc/localtime"); err != nil {
+ log.Warningf("Failed to copy /etc/localtime: %v. UTC timezone will be used.", err)
+ }
root = "/proc/root"
procPath = "/proc/proc"
}
@@ -409,7 +415,7 @@ func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]sp
panic(fmt.Sprintf("%q could not be made relative to %q: %v", dst, root, err))
}
- opts, err := adjustMountOptions(filepath.Join(root, relDst), m.Options)
+ opts, err := adjustMountOptions(conf, filepath.Join(root, relDst), m.Options)
if err != nil {
return nil, err
}
@@ -475,7 +481,7 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro
}
// adjustMountOptions adds 'overlayfs_stale_read' if mounting over overlayfs.
-func adjustMountOptions(path string, opts []string) ([]string, error) {
+func adjustMountOptions(conf *config.Config, path string, opts []string) ([]string, error) {
rv := make([]string, len(opts))
copy(rv, opts)
diff --git a/runsc/container/container.go b/runsc/container/container.go
index 6a9a07afe..d1f979eb2 100644
--- a/runsc/container/container.go
+++ b/runsc/container/container.go
@@ -646,6 +646,12 @@ func (c *Container) Resume() error {
return c.saveLocked()
}
+// Cat prints out the content of the files.
+func (c *Container) Cat(files []string, out *os.File) error {
+ log.Debugf("Cat in container, cid: %s, files: %+v", c.ID, files)
+ return c.Sandbox.Cat(c.ID, files, out)
+}
+
// State returns the metadata of the container.
func (c *Container) State() specs.State {
return specs.State{
diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go
index 5fb4a3672..960c36946 100644
--- a/runsc/container/container_test.go
+++ b/runsc/container/container_test.go
@@ -442,6 +442,11 @@ func configs(t *testing.T, opts ...configOption) map[string]*config.Config {
return all
}
+// sleepSpec generates a spec with sleep 1000 and a conf.
+func sleepSpecConf(t *testing.T) (*specs.Spec, *config.Config) {
+ return testutil.NewSpecWithArgs("sleep", "1000"), testutil.TestConfig(t)
+}
+
// TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle.
// It verifies after each step that the container can be loaded from disk, and
// has the correct status.
@@ -455,7 +460,7 @@ func TestLifecycle(t *testing.T) {
t.Run(name, func(t *testing.T) {
// The container will just sleep for a long time. We will kill it before
// it finishes sleeping.
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec, _ := sleepSpecConf(t)
rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
@@ -903,7 +908,7 @@ func TestExecProcList(t *testing.T) {
for name, conf := range configs(t, all...) {
t.Run(name, func(t *testing.T) {
const uid = 343
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec, _ := sleepSpecConf(t)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
@@ -1422,8 +1427,7 @@ func TestPauseResume(t *testing.T) {
// with calls to pause and resume and that pausing and resuming only
// occurs given the correct state.
func TestPauseResumeStatus(t *testing.T) {
- spec := testutil.NewSpecWithArgs("sleep", "20")
- conf := testutil.TestConfig(t)
+ spec, conf := sleepSpecConf(t)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1490,7 +1494,7 @@ func TestCapabilities(t *testing.T) {
for name, conf := range configs(t, all...) {
t.Run(name, func(t *testing.T) {
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec, _ := sleepSpecConf(t)
rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -1640,7 +1644,7 @@ func TestMountNewDir(t *testing.T) {
func TestReadonlyRoot(t *testing.T) {
for name, conf := range configs(t, all...) {
t.Run(name, func(t *testing.T) {
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec, _ := sleepSpecConf(t)
spec.Root.Readonly = true
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
@@ -1692,7 +1696,7 @@ func TestReadonlyMount(t *testing.T) {
if err != nil {
t.Fatalf("ioutil.TempDir() failed: %v", err)
}
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec, _ := sleepSpecConf(t)
spec.Mounts = append(spec.Mounts, specs.Mount{
Destination: dir,
Source: dir,
@@ -1852,7 +1856,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) {
"baz-" + testutil.RandomContainerID(),
}
for _, cid := range cids {
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec, _ := sleepSpecConf(t)
bundleDir, cleanup, err := testutil.SetupBundleDir(spec)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -2229,7 +2233,7 @@ func TestMountPropagation(t *testing.T) {
t.Fatalf("mount(%q, MS_SHARED): %v", srcMnt, err)
}
- spec := testutil.NewSpecWithArgs("sleep", "1000")
+ spec, conf := sleepSpecConf(t)
priv := filepath.Join(tmpDir, "priv")
slave := filepath.Join(tmpDir, "slave")
@@ -2248,7 +2252,6 @@ func TestMountPropagation(t *testing.T) {
},
}
- conf := testutil.TestConfig(t)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -2563,12 +2566,11 @@ func TestRlimits(t *testing.T) {
// TestRlimitsExec sets limit to number of open files and checks that the limit
// is propagated to exec'd processes.
func TestRlimitsExec(t *testing.T) {
- spec := testutil.NewSpecWithArgs("sleep", "100")
+ spec, conf := sleepSpecConf(t)
spec.Process.Rlimits = []specs.POSIXRlimit{
{Type: "RLIMIT_NOFILE", Hard: 1000, Soft: 100},
}
- conf := testutil.TestConfig(t)
_, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
if err != nil {
t.Fatalf("error setting up container: %v", err)
@@ -2597,3 +2599,59 @@ func TestRlimitsExec(t *testing.T) {
t.Errorf("ulimit result, got: %q, want: %q", got, want)
}
}
+
+// TestCat creates a file and checks that cat generates the expected output.
+func TestCat(t *testing.T) {
+ f, err := ioutil.TempFile(testutil.TmpDir(), "test-case")
+ if err != nil {
+ t.Fatalf("ioutil.TempFile failed: %v", err)
+ }
+ defer os.RemoveAll(f.Name())
+
+ content := "test-cat"
+ if _, err := f.WriteString(content); err != nil {
+ t.Fatalf("f.WriteString(): %v", err)
+ }
+ f.Close()
+
+ spec, conf := sleepSpecConf(t)
+
+ _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf)
+ if err != nil {
+ t.Fatalf("error setting up container: %v", err)
+ }
+ defer cleanup()
+
+ args := Args{
+ ID: testutil.RandomContainerID(),
+ Spec: spec,
+ BundleDir: bundleDir,
+ }
+
+ cont, err := New(conf, args)
+ if err != nil {
+ t.Fatalf("Creating container: %v", err)
+ }
+ defer cont.Destroy()
+
+ if err := cont.Start(conf); err != nil {
+ t.Fatalf("starting container: %v", err)
+ }
+
+ r, w, err := os.Pipe()
+ if err != nil {
+ t.Fatalf("os.Create(): %v", err)
+ }
+
+ if err := cont.Cat([]string{f.Name()}, w); err != nil {
+ t.Fatalf("error cat from container: %v", err)
+ }
+
+ buf := make([]byte, 1024)
+ if _, err := r.Read(buf); err != nil {
+ t.Fatalf("Read out: %v", err)
+ }
+ if got, want := string(buf), content; !strings.Contains(got, want) {
+ t.Errorf("out got %s, want include %s", buf, want)
+ }
+}
diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go
index 5fb7dc834..b15572a98 100644
--- a/runsc/sandbox/sandbox.go
+++ b/runsc/sandbox/sandbox.go
@@ -981,7 +981,7 @@ func (s *Sandbox) Pause(cid string) error {
}
defer conn.Close()
- if err := conn.Call(boot.ContMgrPause, nil, nil); err != nil {
+ if err := conn.Call(boot.LifecyclePause, nil, nil); err != nil {
return fmt.Errorf("pausing container %q: %v", cid, err)
}
return nil
@@ -996,12 +996,30 @@ func (s *Sandbox) Resume(cid string) error {
}
defer conn.Close()
- if err := conn.Call(boot.ContMgrResume, nil, nil); err != nil {
+ if err := conn.Call(boot.LifecycleResume, nil, nil); err != nil {
return fmt.Errorf("resuming container %q: %v", cid, err)
}
return nil
}
+// Cat sends the cat call for a container in the sandbox.
+func (s *Sandbox) Cat(cid string, files []string, out *os.File) error {
+ log.Debugf("Cat sandbox %q", s.ID)
+ conn, err := s.sandboxConnect()
+ if err != nil {
+ return err
+ }
+ defer conn.Close()
+
+ if err := conn.Call(boot.FsCat, &control.CatOpts{
+ Files: files,
+ FilePayload: urpc.FilePayload{Files: []*os.File{out}},
+ }, nil); err != nil {
+ return fmt.Errorf("Cat container %q: %v", cid, err)
+ }
+ return nil
+}
+
// IsRunning returns true if the sandbox or gofer process is running.
func (s *Sandbox) IsRunning() bool {
if s.Pid != 0 {
diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go
index 9e22c9a7d..d41139944 100644
--- a/test/e2e/integration_test.go
+++ b/test/e2e/integration_test.go
@@ -742,3 +742,49 @@ func TestUnmount(t *testing.T) {
t.Fatalf("docker run failed: %v", err)
}
}
+
+func TestDeleteInterface(t *testing.T) {
+ if testutil.IsRunningWithHostNet() {
+ t.Skip("not able to remove interfaces on hostnet")
+ }
+
+ ctx := context.Background()
+ d := dockerutil.MakeContainer(ctx, t)
+ defer d.CleanUp(ctx)
+
+ opts := dockerutil.RunOpts{
+ Image: "basic/alpine",
+ CapAdd: []string{"NET_ADMIN"},
+ }
+ if err := d.Spawn(ctx, opts, "sleep", "1000"); err != nil {
+ t.Fatalf("docker run failed: %v", err)
+ }
+
+ // We should be able to remove eth0.
+ output, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link del dev eth0")
+ if err != nil {
+ t.Fatalf("failed to remove eth0: %s, output: %s", err, output)
+ }
+ // Verify that eth0 is no longer there.
+ output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link show")
+ if err != nil {
+ t.Fatalf("docker exec ip link show failed: %s, output: %s", err, output)
+ }
+ if strings.Contains(output, "eth0") {
+ t.Fatalf("failed to remove eth0")
+ }
+
+ // Loopback device can't be removed.
+ output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link del dev lo")
+ if err == nil {
+ t.Fatalf("should not remove the loopback device: %v", output)
+ }
+ // Verify that lo is still there.
+ output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link show")
+ if err != nil {
+ t.Fatalf("docker exec ip link show failed: %s, output: %s", err, output)
+ }
+ if !strings.Contains(output, "lo") {
+ t.Fatalf("loopback interface is removed")
+ }
+}
diff --git a/test/perf/BUILD b/test/perf/BUILD
index 97ca0e75a..58fe333ee 100644
--- a/test/perf/BUILD
+++ b/test/perf/BUILD
@@ -70,6 +70,13 @@ syscall_test(
)
syscall_test(
+ size = "large",
+ add_overlay = True,
+ debug = False,
+ test = "//test/perf/linux:dup_benchmark",
+)
+
+syscall_test(
debug = False,
test = "//test/perf/linux:pipe_benchmark",
)
@@ -146,3 +153,24 @@ syscall_test(
test = "//test/perf/linux:verity_open_benchmark",
vfs1 = False,
)
+
+syscall_test(
+ size = "large",
+ debug = False,
+ test = "//test/perf/linux:verity_read_benchmark",
+ vfs1 = False,
+)
+
+syscall_test(
+ size = "large",
+ debug = False,
+ test = "//test/perf/linux:verity_randread_benchmark",
+ vfs1 = False,
+)
+
+syscall_test(
+ size = "large",
+ debug = False,
+ test = "//test/perf/linux:verity_open_read_close_benchmark",
+ vfs1 = False,
+)
diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD
index b4f192227..61ed98ff5 100644
--- a/test/perf/linux/BUILD
+++ b/test/perf/linux/BUILD
@@ -109,6 +109,22 @@ cc_binary(
)
cc_binary(
+ name = "dup_benchmark",
+ testonly = 1,
+ srcs = [
+ "dup_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ ],
+)
+
+cc_binary(
name = "read_benchmark",
testonly = 1,
srcs = [
@@ -389,3 +405,60 @@ cc_binary(
"//test/util:verity_util",
],
)
+
+cc_binary(
+ name = "verity_read_benchmark",
+ testonly = 1,
+ srcs = [
+ "verity_read_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:verity_util",
+ ],
+)
+
+cc_binary(
+ name = "verity_randread_benchmark",
+ testonly = 1,
+ srcs = [
+ "verity_randread_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:verity_util",
+ ],
+)
+
+cc_binary(
+ name = "verity_open_read_close_benchmark",
+ testonly = 1,
+ srcs = [
+ "verity_open_read_close_benchmark.cc",
+ ],
+ deps = [
+ gbenchmark,
+ gtest,
+ "//test/util:capability_util",
+ "//test/util:fs_util",
+ "//test/util:logging",
+ "//test/util:temp_path",
+ "//test/util:test_main",
+ "//test/util:test_util",
+ "//test/util:verity_util",
+ ],
+)
diff --git a/test/perf/linux/dup_benchmark.cc b/test/perf/linux/dup_benchmark.cc
new file mode 100644
index 000000000..5d808d225
--- /dev/null
+++ b/test/perf/linux/dup_benchmark.cc
@@ -0,0 +1,55 @@
+// Copyright 2020 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+
+namespace gvisor {
+
+namespace {
+
+void BM_Dup(benchmark::State& state) {
+ const int size = state.range(0);
+
+ for (auto _ : state) {
+ std::vector<int> v;
+ for (int i = 0; i < size; i++) {
+ int fd = dup(2);
+ TEST_CHECK(fd != -1);
+ v.push_back(fd);
+ }
+ for (int i = 0; i < size; i++) {
+ int fd = v[i];
+ close(fd);
+ }
+ }
+ state.SetItemsProcessed(state.iterations() * size);
+}
+
+BENCHMARK(BM_Dup)->Range(1, 1 << 15)->UseRealTime();
+
+} // namespace
+
+} // namespace gvisor
diff --git a/test/perf/linux/verity_open_benchmark.cc b/test/perf/linux/verity_open_benchmark.cc
index ce53a2100..026b6f101 100644
--- a/test/perf/linux/verity_open_benchmark.cc
+++ b/test/perf/linux/verity_open_benchmark.cc
@@ -36,8 +36,6 @@ namespace testing {
namespace {
void BM_Open(benchmark::State& state) {
- SKIP_IF(IsRunningWithVFS1());
-
const int size = state.range(0);
std::vector<TempPath> cache;
std::vector<EnableTarget> targets;
diff --git a/test/perf/linux/verity_open_read_close_benchmark.cc b/test/perf/linux/verity_open_read_close_benchmark.cc
new file mode 100644
index 000000000..e77577f22
--- /dev/null
+++ b/test/perf/linux/verity_open_read_close_benchmark.cc
@@ -0,0 +1,75 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/mount.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/verity_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_VerityOpenReadClose(benchmark::State& state) {
+ const int size = state.range(0);
+
+ // Mount a tmpfs file system to be wrapped by a verity fs.
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TEST_CHECK(mount("", dir.path().c_str(), "tmpfs", 0, "") == 0);
+
+ std::vector<TempPath> cache;
+ std::vector<EnableTarget> targets;
+
+ for (int i = 0; i < size; i++) {
+ auto file = ASSERT_NO_ERRNO_AND_VALUE(
+ TempPath::CreateFileWith(dir.path(), "some contents", 0644));
+ targets.emplace_back(
+ EnableTarget(std::string(Basename(file.path())), O_RDONLY));
+ cache.emplace_back(std::move(file));
+ }
+
+ std::string verity_dir =
+ TEST_CHECK_NO_ERRNO_AND_VALUE(MountVerity(dir.path(), targets));
+
+ char buf[1];
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ const int chosen = rand_r(&seed) % size;
+ int fd = open(JoinPath(verity_dir, targets[chosen].path).c_str(), O_RDONLY);
+ TEST_CHECK(fd != -1);
+ TEST_CHECK(read(fd, buf, 1) == 1);
+ close(fd);
+ }
+}
+
+BENCHMARK(BM_VerityOpenReadClose)->Range(1000, 16384)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/verity_randread_benchmark.cc b/test/perf/linux/verity_randread_benchmark.cc
new file mode 100644
index 000000000..4178cfad8
--- /dev/null
+++ b/test/perf/linux/verity_randread_benchmark.cc
@@ -0,0 +1,108 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/mount.h>
+#include <sys/stat.h>
+#include <sys/uio.h>
+#include <unistd.h>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/verity_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+// Create a 1GB file that will be read from at random positions. This should
+// invalid any performance gains from caching.
+const uint64_t kFileSize = Megabytes(1024);
+
+// How many bytes to write at once to initialize the file used to read from.
+const uint32_t kWriteSize = 65536;
+
+// Largest benchmarked read unit.
+const uint32_t kMaxRead = Megabytes(64);
+
+// Global test state, initialized once per process lifetime.
+struct GlobalState {
+ explicit GlobalState() {
+ // Mount a tmpfs file system to be wrapped by a verity fs.
+ tmp_dir_ = TempPath::CreateDir().ValueOrDie();
+ TEST_CHECK(mount("", tmp_dir_.path().c_str(), "tmpfs", 0, "") == 0);
+ file_ = TempPath::CreateFileIn(tmp_dir_.path()).ValueOrDie();
+ filename_ = std::string(Basename(file_.path()));
+
+ FileDescriptor fd = Open(file_.path(), O_WRONLY).ValueOrDie();
+
+ // Try to minimize syscalls by using maximum size writev() requests.
+ std::vector<char> buffer(kWriteSize);
+ RandomizeBuffer(buffer.data(), buffer.size());
+ const std::vector<std::vector<struct iovec>> iovecs_list =
+ GenerateIovecs(kFileSize + kMaxRead, buffer.data(), buffer.size());
+ for (const auto& iovecs : iovecs_list) {
+ TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0);
+ }
+ verity_dir_ =
+ MountVerity(tmp_dir_.path(), {EnableTarget(filename_, O_RDONLY)})
+ .ValueOrDie();
+ }
+ TempPath tmp_dir_;
+ TempPath file_;
+ std::string verity_dir_;
+ std::string filename_;
+};
+
+GlobalState& GetGlobalState() {
+ // This gets created only once throughout the lifetime of the process.
+ // Use a dynamically allocated object (that is never deleted) to avoid order
+ // of destruction of static storage variables issues.
+ static GlobalState* const state =
+ // The actual file size is the maximum random seek range (kFileSize) + the
+ // maximum read size so we can read that number of bytes at the end of the
+ // file.
+ new GlobalState();
+ return *state;
+}
+
+void BM_VerityRandRead(benchmark::State& state) {
+ const int size = state.range(0);
+
+ GlobalState& global_state = GetGlobalState();
+ FileDescriptor verity_fd = ASSERT_NO_ERRNO_AND_VALUE(Open(
+ JoinPath(global_state.verity_dir_, global_state.filename_), O_RDONLY));
+ std::vector<char> buf(size);
+
+ unsigned int seed = 1;
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(verity_fd.get(), buf.data(), buf.size(),
+ rand_r(&seed) % kFileSize) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_VerityRandRead)->Range(1, kMaxRead)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/perf/linux/verity_read_benchmark.cc b/test/perf/linux/verity_read_benchmark.cc
new file mode 100644
index 000000000..738b5ba45
--- /dev/null
+++ b/test/perf/linux/verity_read_benchmark.cc
@@ -0,0 +1,69 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <fcntl.h>
+#include <stdlib.h>
+#include <sys/mount.h>
+#include <unistd.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "benchmark/benchmark.h"
+#include "test/util/capability_util.h"
+#include "test/util/fs_util.h"
+#include "test/util/logging.h"
+#include "test/util/temp_path.h"
+#include "test/util/test_util.h"
+#include "test/util/verity_util.h"
+
+namespace gvisor {
+namespace testing {
+
+namespace {
+
+void BM_VerityRead(benchmark::State& state) {
+ const int size = state.range(0);
+ const std::string contents(size, 0);
+
+ // Mount a tmpfs file system to be wrapped by a verity fs.
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir());
+ TEST_CHECK(mount("", dir.path().c_str(), "tmpfs", 0, "") == 0);
+
+ auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith(
+ dir.path(), contents, TempPath::kDefaultFileMode));
+ std::string filename = std::string(Basename(path.path()));
+
+ std::string verity_dir = TEST_CHECK_NO_ERRNO_AND_VALUE(
+ MountVerity(dir.path(), {EnableTarget(filename, O_RDONLY)}));
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(Open(JoinPath(verity_dir, filename), O_RDONLY));
+ std::vector<char> buf(size);
+ for (auto _ : state) {
+ TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size);
+ }
+
+ state.SetBytesProcessed(static_cast<int64_t>(size) *
+ static_cast<int64_t>(state.iterations()));
+}
+
+BENCHMARK(BM_VerityRead)->Range(1, 1 << 26)->UseRealTime();
+
+} // namespace
+
+} // namespace testing
+} // namespace gvisor
diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go
index 58fcd6f08..5114a9602 100644
--- a/test/root/chroot_test.go
+++ b/test/root/chroot_test.go
@@ -68,13 +68,15 @@ func TestChroot(t *testing.T) {
if err != nil {
t.Fatalf("error listing %q: %v", chroot, err)
}
- if want, got := 1, len(fi); want != got {
+ if want, got := 2, len(fi); want != got {
t.Fatalf("chroot dir got %d entries, want %d", got, want)
}
- // chroot dir is prepared by runsc and should contains only /proc.
- if fi[0].Name() != "proc" {
- t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc")
+ // chroot dir is prepared by runsc and should contains only /etc and /proc.
+ for i, want := range []string{"etc", "proc"} {
+ if got := fi[i].Name(); got != want {
+ t.Errorf("chroot got child %v, want %v", got, want)
+ }
}
d.CleanUp(ctx)
diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD
index 7185df076..7129a797b 100644
--- a/test/syscalls/linux/BUILD
+++ b/test/syscalls/linux/BUILD
@@ -560,6 +560,7 @@ cc_binary(
deps = [
"//test/util:eventfd_util",
"//test/util:file_descriptor",
+ "@com_google_absl//absl/memory",
gtest,
"//test/util:fs_util",
"//test/util:posix_error",
@@ -1750,6 +1751,7 @@ cc_binary(
"//test/util:mount_util",
"@com_google_absl//absl/container:node_hash_set",
"@com_google_absl//absl/strings",
+ "@com_google_absl//absl/strings:str_format",
"@com_google_absl//absl/synchronization",
"@com_google_absl//absl/time",
gtest,
@@ -4170,9 +4172,11 @@ cc_binary(
srcs = ["msgqueue.cc"],
linkstatic = 1,
deps = [
+ "//test/util:capability_util",
"//test/util:temp_path",
"//test/util:test_main",
"//test/util:test_util",
+ "@com_google_absl//absl/time",
],
)
diff --git a/test/syscalls/linux/dup.cc b/test/syscalls/linux/dup.cc
index ba4e13fb9..8f0974f45 100644
--- a/test/syscalls/linux/dup.cc
+++ b/test/syscalls/linux/dup.cc
@@ -13,9 +13,11 @@
// limitations under the License.
#include <fcntl.h>
+#include <sys/resource.h>
#include <unistd.h>
#include "gtest/gtest.h"
+#include "absl/memory/memory.h"
#include "test/util/eventfd_util.h"
#include "test/util/file_descriptor.h"
#include "test/util/fs_util.h"
@@ -98,6 +100,45 @@ TEST(DupTest, Dup2) {
ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2));
}
+TEST(DupTest, Rlimit) {
+ auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
+ FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
+
+ struct rlimit rl = {};
+ EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+
+ ASSERT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+
+ constexpr int kFDLimit = 101;
+ // Create a file descriptor that will be above the limit.
+ FileDescriptor aboveLimitFD =
+ ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, kFDLimit * 2 - 1));
+
+ rl.rlim_cur = kFDLimit;
+ ASSERT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds());
+ ASSERT_THAT(dup3(fd.get(), kFDLimit, 0), SyscallFails());
+
+ std::vector<std::unique_ptr<FileDescriptor>> fds;
+ int prev_fd = fd.get();
+ int used_fds = 0;
+ for (int i = 0; i < kFDLimit; ++i) {
+ int new_fd = dup(fd.get());
+ if (new_fd == -1) {
+ break;
+ }
+ auto f = absl::make_unique<FileDescriptor>(new_fd);
+ EXPECT_LT(new_fd, kFDLimit);
+ EXPECT_GT(new_fd, prev_fd);
+ // Check that all fds in (prev_fd, new_fd) are used.
+ for (int j = prev_fd + 1; j < new_fd; ++j) {
+ if (fcntl(j, F_GETFD) != -1) used_fds++;
+ }
+ prev_fd = new_fd;
+ fds.push_back(std::move(f));
+ }
+ EXPECT_EQ(fds.size() + used_fds, kFDLimit - fd.get() - 1);
+}
+
TEST(DupTest, Dup2SameFD) {
auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY));
diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc
index d4f89527c..dbc21833f 100644
--- a/test/syscalls/linux/lseek.cc
+++ b/test/syscalls/linux/lseek.cc
@@ -121,7 +121,8 @@ TEST(LseekTest, InvalidFD) {
}
TEST(LseekTest, DirCurEnd) {
- const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/tmp", O_RDONLY));
+ const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(
+ Open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY));
ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0));
}
diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc
index 4a450742b..dbd1c93ae 100644
--- a/test/syscalls/linux/memfd.cc
+++ b/test/syscalls/linux/memfd.cc
@@ -445,9 +445,10 @@ TEST(MemfdTest, SealsAreInodeLevelProperties) {
// Tmpfs files also support seals, but are created with F_SEAL_SEAL.
TEST(MemfdTest, TmpfsFilesHaveSealSeal) {
- SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs("/tmp")));
+ std::string tmpdir = GetAbsoluteTestTmpdir();
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(tmpdir.c_str())));
const TempPath tmpfs_file =
- ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn("/tmp"));
+ ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(tmpdir.c_str()));
const FileDescriptor fd =
ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfs_file.path(), O_RDWR, 0644));
EXPECT_THAT(fcntl(fd.get(), F_GET_SEALS),
diff --git a/test/syscalls/linux/msgqueue.cc b/test/syscalls/linux/msgqueue.cc
index 2409de7e8..837e913d9 100644
--- a/test/syscalls/linux/msgqueue.cc
+++ b/test/syscalls/linux/msgqueue.cc
@@ -12,10 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
+#include <errno.h>
#include <sys/ipc.h>
#include <sys/msg.h>
#include <sys/types.h>
+#include "absl/time/clock.h"
+#include "test/util/capability_util.h"
#include "test/util/temp_path.h"
#include "test/util/test_util.h"
@@ -23,6 +26,10 @@ namespace gvisor {
namespace testing {
namespace {
+constexpr int msgMax = 8192; // Max size for message in bytes.
+constexpr int msgMni = 32000; // Max number of identifiers.
+constexpr int msgMnb = 16384; // Default max size of message queue in bytes.
+
// Queue is a RAII class used to automatically clean message queues.
class Queue {
public:
@@ -46,6 +53,25 @@ class Queue {
int id_ = -1;
};
+// Default size for messages.
+constexpr size_t msgSize = 50;
+
+// msgbuf is a simple buffer using to send and receive text messages for
+// testing purposes.
+struct msgbuf {
+ int64_t mtype;
+ char mtext[msgSize];
+};
+
+bool operator==(msgbuf& a, msgbuf& b) {
+ for (size_t i = 0; i < msgSize; i++) {
+ if (a.mtext[i] != b.mtext[i]) {
+ return false;
+ }
+ }
+ return a.mtype == b.mtype;
+}
+
// Test simple creation and retrieval for msgget(2).
TEST(MsgqueueTest, MsgGet) {
const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile());
@@ -82,6 +108,552 @@ TEST(MsgqueueTest, MsgGetIpcPrivate) {
EXPECT_NE(queue1.get(), queue2.get());
}
+// Test simple msgsnd and msgrcv.
+TEST(MsgqueueTest, MsgOpSimple) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, "A message."};
+ msgbuf rcv;
+
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0),
+ SyscallSucceedsWithValue(sizeof(buf.mtext)));
+ EXPECT_TRUE(buf == rcv);
+}
+
+// Test msgsnd and msgrcv of an empty message.
+TEST(MsgqueueTest, MsgOpEmpty) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, ""};
+ msgbuf rcv;
+
+ ASSERT_THAT(msgsnd(queue.get(), &buf, 0, 0), SyscallSucceeds());
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0),
+ SyscallSucceedsWithValue(0));
+}
+
+// Test truncation of message with MSG_NOERROR flag.
+TEST(MsgqueueTest, MsgOpTruncate) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, ""};
+ msgbuf rcv;
+
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) - 1, 0, MSG_NOERROR),
+ SyscallSucceedsWithValue(sizeof(buf.mtext) - 1));
+}
+
+// Test msgsnd and msgrcv using invalid arguments.
+TEST(MsgqueueTest, MsgOpInvalidArgs) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, ""};
+
+ EXPECT_THAT(msgsnd(-1, &buf, 0, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(msgsnd(queue.get(), &buf, -1, 0), SyscallFailsWithErrno(EINVAL));
+
+ buf.mtype = -1;
+ EXPECT_THAT(msgsnd(queue.get(), &buf, 1, 0), SyscallFailsWithErrno(EINVAL));
+
+ EXPECT_THAT(msgrcv(-1, &buf, 1, 0, 0), SyscallFailsWithErrno(EINVAL));
+ EXPECT_THAT(msgrcv(queue.get(), &buf, -1, 0, 0),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Test non-blocking msgrcv with an empty queue.
+TEST(MsgqueueTest, MsgOpNoMsg) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf rcv;
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(rcv.mtext) + 1, 0, IPC_NOWAIT),
+ SyscallFailsWithErrno(ENOMSG));
+}
+
+// Test non-blocking msgrcv with a non-empty queue, but no messages of wanted
+// type.
+TEST(MsgqueueTest, MsgOpNoMsgType) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, ""};
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+
+ EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext) + 1, 2, IPC_NOWAIT),
+ SyscallFailsWithErrno(ENOMSG));
+}
+
+// Test msgrcv with a larger size message than wanted, and truncation disabled.
+TEST(MsgqueueTest, MsgOpTooBig) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, ""};
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+
+ EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext) - 1, 0, 0),
+ SyscallFailsWithErrno(E2BIG));
+}
+
+// Test receiving messages based on type.
+TEST(MsgqueueTest, MsgRcvType) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ // Send messages in an order and receive them in reverse, based on type,
+ // which shouldn't block.
+ std::map<int64_t, msgbuf> typeToBuf = {
+ {1, msgbuf{1, "Message 1."}}, {2, msgbuf{2, "Message 2."}},
+ {3, msgbuf{3, "Message 3."}}, {4, msgbuf{4, "Message 4."}},
+ {5, msgbuf{5, "Message 5."}}, {6, msgbuf{6, "Message 6."}},
+ {7, msgbuf{7, "Message 7."}}, {8, msgbuf{8, "Message 8."}},
+ {9, msgbuf{9, "Message 9."}}};
+
+ for (auto const& [type, buf] : typeToBuf) {
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+ }
+
+ for (int64_t i = typeToBuf.size(); i > 0; i--) {
+ msgbuf rcv;
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(typeToBuf[i].mtext) + 1, i, 0),
+ SyscallSucceedsWithValue(sizeof(typeToBuf[i].mtext)));
+ EXPECT_TRUE(typeToBuf[i] == rcv);
+ }
+}
+
+// Test using MSG_EXCEPT to receive a different-type message.
+TEST(MsgqueueTest, MsgExcept) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ std::map<int64_t, msgbuf> typeToBuf = {
+ {1, msgbuf{1, "Message 1."}},
+ {2, msgbuf{2, "Message 2."}},
+ };
+
+ for (auto const& [type, buf] : typeToBuf) {
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+ }
+
+ for (int64_t i = typeToBuf.size(); i > 0; i--) {
+ msgbuf actual = typeToBuf[i == 1 ? 2 : 1];
+ msgbuf rcv;
+
+ EXPECT_THAT(
+ msgrcv(queue.get(), &rcv, sizeof(actual.mtext) + 1, i, MSG_EXCEPT),
+ SyscallSucceedsWithValue(sizeof(actual.mtext)));
+ EXPECT_TRUE(actual == rcv);
+ }
+}
+
+// Test msgrcv with a negative type.
+TEST(MsgqueueTest, MsgRcvTypeNegative) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ // When msgtyp is negative, msgrcv returns the first message with mtype less
+ // than or equal to the absolute value.
+ msgbuf buf{2, "A message."};
+ msgbuf rcv;
+
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+
+ // Nothing is less than or equal to 1.
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, -1, IPC_NOWAIT),
+ SyscallFailsWithErrno(ENOMSG));
+
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, -3, 0),
+ SyscallSucceedsWithValue(sizeof(buf.mtext)));
+ EXPECT_TRUE(buf == rcv);
+}
+
+// Test permission-related failure scenarios.
+TEST(MsgqueueTest, MsgOpPermissions) {
+ AutoCapability cap(CAP_IPC_OWNER, false);
+
+ Queue queue(msgget(IPC_PRIVATE, 0000));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, ""};
+
+ EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallFailsWithErrno(EACCES));
+ EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext), 0, 0),
+ SyscallFailsWithErrno(EACCES));
+}
+
+// Test limits for messages and queues.
+TEST(MsgqueueTest, MsgOpLimits) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, "A message."};
+
+ // Limit for one message.
+ EXPECT_THAT(msgsnd(queue.get(), &buf, msgMax + 1, 0),
+ SyscallFailsWithErrno(EINVAL));
+
+ // Limit for queue.
+ // Use a buffer with the maximum mount of bytes that can be transformed to
+ // make it easier to exhaust the queue limit.
+ struct msgmax {
+ int64_t mtype;
+ char mtext[msgMax];
+ };
+
+ msgmax limit{1, ""};
+ for (size_t i = 0, msgCount = msgMnb / msgMax; i < msgCount; i++) {
+ EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), 0),
+ SyscallSucceeds());
+ }
+ EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), IPC_NOWAIT),
+ SyscallFailsWithErrno(EAGAIN));
+}
+
+// MsgCopySupported returns true if MSG_COPY is supported.
+bool MsgCopySupported() {
+ // msgrcv(2) man page states that MSG_COPY flag is available only if the
+ // kernel was built with the CONFIG_CHECKPOINT_RESTORE option. If MSG_COPY
+ // is used when the kernel was configured without the option, msgrcv produces
+ // a ENOSYS error.
+ // To avoid test failure, we perform a small test using msgrcv, and skip the
+ // test if errno == ENOSYS. This means that the test will always run on
+ // gVisor, but may be skipped on native linux.
+
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+
+ msgbuf buf{1, "Test message."};
+ msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0);
+
+ return !(msgrcv(queue.get(), &buf, sizeof(buf.mtext) + 1, 0,
+ MSG_COPY | IPC_NOWAIT) == -1 &&
+ errno == ENOSYS);
+}
+
+// Test msgrcv using MSG_COPY.
+TEST(MsgqueueTest, MsgCopy) {
+ SKIP_IF(!MsgCopySupported());
+
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf bufs[5] = {
+ msgbuf{1, "Message 1."}, msgbuf{2, "Message 2."}, msgbuf{3, "Message 3."},
+ msgbuf{4, "Message 4."}, msgbuf{5, "Message 5."},
+ };
+
+ for (auto& buf : bufs) {
+ ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+ }
+
+ // Receive a copy of the messages.
+ for (size_t i = 0, size = sizeof(bufs) / sizeof(bufs[0]); i < size; i++) {
+ msgbuf buf = bufs[i];
+ msgbuf rcv;
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, i,
+ MSG_COPY | IPC_NOWAIT),
+ SyscallSucceedsWithValue(sizeof(buf.mtext)));
+ EXPECT_TRUE(buf == rcv);
+ }
+
+ // Re-receive the messages normally.
+ for (auto& buf : bufs) {
+ msgbuf rcv;
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0),
+ SyscallSucceedsWithValue(sizeof(buf.mtext)));
+ EXPECT_TRUE(buf == rcv);
+ }
+}
+
+// Test msgrcv using MSG_COPY with invalid arguments.
+TEST(MsgqueueTest, MsgCopyInvalidArgs) {
+ SKIP_IF(!MsgCopySupported());
+
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf rcv;
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, 1, MSG_COPY),
+ SyscallFailsWithErrno(EINVAL));
+
+ EXPECT_THAT(
+ msgrcv(queue.get(), &rcv, msgSize, 5, MSG_COPY | MSG_EXCEPT | IPC_NOWAIT),
+ SyscallFailsWithErrno(EINVAL));
+}
+
+// Test msgrcv using MSG_COPY with invalid indices.
+TEST(MsgqueueTest, MsgCopyInvalidIndex) {
+ SKIP_IF(!MsgCopySupported());
+
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf rcv;
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, -3, MSG_COPY | IPC_NOWAIT),
+ SyscallFailsWithErrno(ENOMSG));
+
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, 5, MSG_COPY | IPC_NOWAIT),
+ SyscallFailsWithErrno(ENOMSG));
+}
+
+// Test msgrcv (most probably) blocking on an empty queue.
+TEST(MsgqueueTest, MsgRcvBlocking) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf buf{1, "A message."};
+
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ msgbuf rcv;
+ TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0,
+ 0) == sizeof(buf.mtext) &&
+ buf == rcv);
+ _exit(0);
+ }
+
+ // Sleep to try and make msgrcv block before sending a message.
+ absl::SleepFor(absl::Milliseconds(150));
+
+ EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+// Test msgrcv (most probably) waiting for a specific-type message.
+TEST(MsgqueueTest, MsgRcvTypeBlocking) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ msgbuf bufs[5] = {{1, "A message."},
+ {1, "A message."},
+ {1, "A message."},
+ {1, "A message."},
+ {2, "A different message."}};
+
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ msgbuf buf = bufs[4]; // Buffer that should be received.
+ msgbuf rcv;
+ TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, 2,
+ 0) == sizeof(buf.mtext) &&
+ buf == rcv);
+ _exit(0);
+ }
+
+ // Sleep to try and make msgrcv block before sending messages.
+ absl::SleepFor(absl::Milliseconds(150));
+
+ // Send all buffers in order, only last one should be received.
+ for (auto& buf : bufs) {
+ EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+ }
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+// Test msgsnd (most probably) blocking on a full queue.
+TEST(MsgqueueTest, MsgSndBlocking) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ // Use a buffer with the maximum mount of bytes that can be transformed to
+ // make it easier to exhaust the queue limit.
+ struct msgmax {
+ int64_t mtype;
+ char mtext[msgMax];
+ };
+
+ msgmax buf{1, ""}; // Has max amount of bytes.
+
+ const size_t msgCount = msgMnb / msgMax; // Number of messages that can be
+ // sent without blocking.
+
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ // Fill the queue.
+ for (size_t i = 0; i < msgCount; i++) {
+ TEST_PCHECK(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0) == 0);
+ }
+
+ // Next msgsnd should block.
+ TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) ==
+ 0);
+ _exit(0);
+ }
+
+ // To increase the chance of the last msgsnd blocking before doing a msgrcv,
+ // we use MSG_COPY option to copy the last index in the queue. As long as
+ // MSG_COPY fails, the queue hasn't yet been filled. When MSG_COPY succeeds,
+ // the queue is filled, and most probably, a blocking msgsnd has been made.
+ msgmax rcv;
+ while (msgrcv(queue.get(), &rcv, msgMax, msgCount - 1,
+ MSG_COPY | IPC_NOWAIT) == -1 &&
+ errno == ENOMSG) {
+ }
+
+ // Delay a bit more for the blocking msgsnd.
+ absl::SleepFor(absl::Milliseconds(100));
+
+ EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext), 0, 0),
+ SyscallSucceedsWithValue(sizeof(buf.mtext)));
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+// Test removing a queue while a blocking msgsnd is executing.
+TEST(MsgqueueTest, MsgSndRmWhileBlocking) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ // Use a buffer with the maximum mount of bytes that can be transformed to
+ // make it easier to exhaust the queue limit.
+ struct msgmax {
+ int64_t mtype;
+ char mtext[msgMax];
+ };
+
+ const size_t msgCount = msgMnb / msgMax; // Number of messages that can be
+ // sent without blocking.
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ // Fill the queue.
+ msgmax buf{1, ""};
+ for (size_t i = 0; i < msgCount; i++) {
+ EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0),
+ SyscallSucceeds());
+ }
+
+ // Next msgsnd should block. Because we're repeating on EINTR, msgsnd may
+ // race with msgctl(IPC_RMID) and return EINVAL.
+ TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) ==
+ -1 &&
+ (errno == EIDRM || errno == EINVAL));
+ _exit(0);
+ }
+
+ // Similar to MsgSndBlocking, we do this to increase the chance of msgsnd
+ // blocking before removing the queue.
+ msgmax rcv;
+ while (msgrcv(queue.get(), &rcv, msgMax, msgCount - 1,
+ MSG_COPY | IPC_NOWAIT) == -1 &&
+ errno == ENOMSG) {
+ }
+ absl::SleepFor(absl::Milliseconds(100));
+
+ EXPECT_THAT(msgctl(queue.release(), IPC_RMID, nullptr), SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+// Test removing a queue while a blocking msgrcv is executing.
+TEST(MsgqueueTest, MsgRcvRmWhileBlocking) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ // Because we're repeating on EINTR, msgsnd may race with msgctl(IPC_RMID)
+ // and return EINVAL.
+ msgbuf rcv;
+ TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, 1, 2, 0) == -1 &&
+ (errno == EIDRM || errno == EINVAL));
+ _exit(0);
+ }
+
+ // Sleep to try and make msgrcv block before sending messages.
+ absl::SleepFor(absl::Milliseconds(150));
+
+ EXPECT_THAT(msgctl(queue.release(), IPC_RMID, nullptr), SyscallSucceeds());
+
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0),
+ SyscallSucceedsWithValue(child_pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+}
+
+// Test a collection of msgsnd/msgrcv operations in different processes.
+TEST(MsgqueueTest, MsgOpGeneral) {
+ Queue queue(msgget(IPC_PRIVATE, 0600));
+ ASSERT_THAT(queue.get(), SyscallSucceeds());
+
+ // Create 50 sending, and 50 receiving processes. There are only 5 messages to
+ // be sent and received, each with a different type. All messages will be sent
+ // and received equally (10 of each.) By the end of the test all processes
+ // should unblock and return normally.
+ const size_t msgCount = 5;
+ std::map<int64_t, msgbuf> typeToBuf = {{1, msgbuf{1, "Message 1."}},
+ {2, msgbuf{2, "Message 2."}},
+ {3, msgbuf{3, "Message 3."}},
+ {4, msgbuf{4, "Message 4."}},
+ {5, msgbuf{5, "Message 5."}}};
+
+ std::vector<pid_t> children;
+
+ const size_t pCount = 50;
+ for (size_t i = 1; i <= pCount; i++) {
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ msgbuf buf = typeToBuf[(i % msgCount) + 1];
+ msgbuf rcv;
+ TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1,
+ (i % msgCount) + 1,
+ 0) == sizeof(buf.mtext) &&
+ buf == rcv);
+ _exit(0);
+ }
+ children.push_back(child_pid);
+ }
+
+ for (size_t i = 1; i <= pCount; i++) {
+ const pid_t child_pid = fork();
+ if (child_pid == 0) {
+ msgbuf buf = typeToBuf[(i % msgCount) + 1];
+ TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) ==
+ 0);
+ _exit(0);
+ }
+ children.push_back(child_pid);
+ }
+
+ for (auto const& pid : children) {
+ int status;
+ ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0),
+ SyscallSucceedsWithValue(pid));
+ EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0);
+ }
+}
+
} // namespace
} // namespace testing
} // namespace gvisor
diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc
index 25b0e63d4..286b3d168 100644
--- a/test/syscalls/linux/prctl.cc
+++ b/test/syscalls/linux/prctl.cc
@@ -214,6 +214,12 @@ TEST(PrctlTest, RootDumpability) {
SyscallFailsWithErrno(EINVAL));
}
+TEST(PrctlTest, SetGetSubreaper) {
+ // Setting subreaper on PID 1 works vacuously because PID 1 is always a
+ // subreaper.
+ EXPECT_THAT(prctl(PR_SET_CHILD_SUBREAPER, 1), SyscallSucceeds());
+}
+
} // namespace
} // namespace testing
diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc
index 78aa73edc..8a4025fed 100644
--- a/test/syscalls/linux/proc.cc
+++ b/test/syscalls/linux/proc.cc
@@ -54,6 +54,8 @@
#include "absl/strings/match.h"
#include "absl/strings/numbers.h"
#include "absl/strings/str_cat.h"
+#include "absl/strings/str_format.h"
+#include "absl/strings/str_join.h"
#include "absl/strings/str_split.h"
#include "absl/strings/string_view.h"
#include "absl/synchronization/mutex.h"
@@ -88,6 +90,7 @@ using ::testing::Gt;
using ::testing::HasSubstr;
using ::testing::IsSupersetOf;
using ::testing::Pair;
+using ::testing::StartsWith;
using ::testing::UnorderedElementsAre;
using ::testing::UnorderedElementsAreArray;
@@ -1622,10 +1625,41 @@ TEST(ProcPidStatusTest, HasBasicFields) {
ASSERT_FALSE(status_str.empty());
const auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(status_str));
- EXPECT_THAT(status, IsSupersetOf({Pair("Name", thread_name),
- Pair("Tgid", absl::StrCat(tgid)),
- Pair("Pid", absl::StrCat(tid)),
- Pair("PPid", absl::StrCat(getppid()))}));
+ EXPECT_THAT(status, IsSupersetOf({
+ Pair("Name", thread_name),
+ Pair("Tgid", absl::StrCat(tgid)),
+ Pair("Pid", absl::StrCat(tid)),
+ Pair("PPid", absl::StrCat(getppid())),
+ }));
+
+ if (!IsRunningWithVFS1()) {
+ uid_t ruid, euid, suid;
+ ASSERT_THAT(getresuid(&ruid, &euid, &suid), SyscallSucceeds());
+ gid_t rgid, egid, sgid;
+ ASSERT_THAT(getresgid(&rgid, &egid, &sgid), SyscallSucceeds());
+ std::vector<gid_t> supplementary_gids;
+ int ngids = getgroups(0, nullptr);
+ supplementary_gids.resize(ngids);
+ ASSERT_THAT(getgroups(ngids, supplementary_gids.data()),
+ SyscallSucceeds());
+
+ EXPECT_THAT(
+ status,
+ IsSupersetOf(std::vector<
+ ::testing::Matcher<std::pair<std::string, std::string>>>{
+ // gVisor doesn't support fsuid/gid, and even if it did there is
+ // no getfsuid/getfsgid().
+ Pair("Uid", StartsWith(absl::StrFormat("%d\t%d\t%d\t", ruid, euid,
+ suid))),
+ Pair("Gid", StartsWith(absl::StrFormat("%d\t%d\t%d\t", rgid, egid,
+ sgid))),
+ // ParseProcStatus strips leading whitespace for each value,
+ // so if the Groups line is empty then the trailing space is
+ // stripped.
+ Pair("Groups",
+ StartsWith(absl::StrJoin(supplementary_gids, " "))),
+ }));
+ }
});
}
diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc
index a5c788346..d5e1ce0cc 100644
--- a/test/syscalls/linux/socket_netlink_route.cc
+++ b/test/syscalls/linux/socket_netlink_route.cc
@@ -44,6 +44,7 @@ namespace {
constexpr uint32_t kSeq = 12345;
+using ::testing::_;
using ::testing::AnyOf;
using ::testing::Eq;
@@ -244,7 +245,7 @@ TEST(NetlinkRouteTest, GetLinkByIndexNotFound) {
req.ifm.ifi_index = 1234590;
EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
- PosixErrorIs(ENODEV, ::testing::_));
+ PosixErrorIs(ENODEV, _));
}
TEST(NetlinkRouteTest, GetLinkByNameNotFound) {
@@ -273,7 +274,112 @@ TEST(NetlinkRouteTest, GetLinkByNameNotFound) {
NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
- PosixErrorIs(ENODEV, ::testing::_));
+ PosixErrorIs(ENODEV, _));
+}
+
+TEST(NetlinkRouteTest, RemoveLoopbackByName) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_DELLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(loopback_link.name.size() + 1);
+ strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENOTSUP, _));
+}
+
+TEST(NetlinkRouteTest, RemoveLoopbackByIndex) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+ Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink());
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_DELLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = loopback_link.index;
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENOTSUP, _));
+}
+
+TEST(NetlinkRouteTest, RemoveLinkByIndexNotFound) {
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_len = sizeof(req);
+ req.hdr.nlmsg_type = RTM_GETLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.ifm.ifi_index = 1234590;
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, _));
+}
+
+TEST(NetlinkRouteTest, RemoveLinkByNameNotFound) {
+ const std::string name = "nodevice?!";
+
+ SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN)));
+ FileDescriptor fd =
+ ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE));
+
+ struct request {
+ struct nlmsghdr hdr;
+ struct ifinfomsg ifm;
+ struct rtattr rtattr;
+ char ifname[IFNAMSIZ];
+ char pad[NLMSG_ALIGNTO + RTA_ALIGNTO];
+ };
+
+ struct request req = {};
+ req.hdr.nlmsg_type = RTM_DELLINK;
+ req.hdr.nlmsg_flags = NLM_F_REQUEST;
+ req.hdr.nlmsg_seq = kSeq;
+ req.ifm.ifi_family = AF_UNSPEC;
+ req.rtattr.rta_type = IFLA_IFNAME;
+ req.rtattr.rta_len = RTA_LENGTH(name.size() + 1);
+ strncpy(req.ifname, name.c_str(), sizeof(req.ifname));
+ req.hdr.nlmsg_len =
+ NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len);
+
+ EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
+ PosixErrorIs(ENODEV, _));
}
TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
@@ -295,7 +401,7 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) {
req.ifm.ifi_family = AF_UNSPEC;
EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)),
- PosixErrorIs(EOPNOTSUPP, ::testing::_));
+ PosixErrorIs(EOPNOTSUPP, _));
}
TEST(NetlinkRouteTest, MsgHdrMsgTrunc) {
@@ -536,7 +642,7 @@ TEST(NetlinkRouteTest, AddAndRemoveAddr) {
// Second delete should fail, as address no longer exists.
EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET,
/*prefixlen=*/24, &addr, sizeof(addr)),
- PosixErrorIs(EADDRNOTAVAIL, ::testing::_));
+ PosixErrorIs(EADDRNOTAVAIL, _));
});
// Replace an existing address should succeed.
@@ -546,7 +652,7 @@ TEST(NetlinkRouteTest, AddAndRemoveAddr) {
// Create exclusive should fail, as we created the address above.
EXPECT_THAT(LinkAddExclusiveLocalAddr(loopback_link.index, AF_INET,
/*prefixlen=*/24, &addr, sizeof(addr)),
- PosixErrorIs(EEXIST, ::testing::_));
+ PosixErrorIs(EEXIST, _));
}
// GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request.
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc
index 72f888659..19dc80d0c 100644
--- a/test/syscalls/linux/stat.cc
+++ b/test/syscalls/linux/stat.cc
@@ -765,7 +765,7 @@ TEST_F(StatTest, StatxSymlink) {
SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 &&
errno == ENOSYS);
- std::string parent_dir = "/tmp";
+ std::string parent_dir = GetAbsoluteTestTmpdir();
TempPath link = ASSERT_NO_ERRNO_AND_VALUE(
TempPath::CreateSymlinkTo(parent_dir, test_file_name_));
std::string p = link.path();
diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc
index d4ea8e026..d057cdc09 100644
--- a/test/syscalls/linux/statfs.cc
+++ b/test/syscalls/linux/statfs.cc
@@ -28,7 +28,7 @@ namespace testing {
namespace {
TEST(StatfsTest, CannotStatBadPath) {
- auto temp_file = NewTempAbsPathInDir("/tmp");
+ auto temp_file = NewTempAbsPath();
struct statfs st;
EXPECT_THAT(statfs(temp_file.c_str(), &st), SyscallFailsWithErrno(ENOENT));
diff --git a/tools/checklinkname/BUILD b/tools/checklinkname/BUILD
new file mode 100644
index 000000000..0f1b07e24
--- /dev/null
+++ b/tools/checklinkname/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "checklinkname",
+ srcs = [
+ "check_linkname.go",
+ "known.go",
+ ],
+ nogo = False,
+ visibility = ["//tools/nogo:__subpackages__"],
+ deps = [
+ "@org_golang_x_tools//go/analysis:go_default_library",
+ ],
+)
diff --git a/tools/checklinkname/README.md b/tools/checklinkname/README.md
new file mode 100644
index 000000000..06b3c302d
--- /dev/null
+++ b/tools/checklinkname/README.md
@@ -0,0 +1,54 @@
+# `checklinkname` Analyzer
+
+`checklinkname` is an analyzer to provide rudimentary type-checking for
+`//go:linkname` directives. Since `//go:linkname` only affects linker behavior,
+there is no built-in type safety and it is the programmer's responsibility to
+ensure the types on either side are compatible.
+
+`checklinkname` helps with this by checking that uses match expectations, as
+defined in this package.
+
+`known.go` contains the set of known linkname targets. For most functions, we
+expect identical types on both sides of the linkname. In a few cases, the types
+may be slightly different (e.g., local redefinition of internal type). It is
+still the responsibility of the programmer to ensure the signatures in
+`known.go` are compatible and safe.
+
+## Findings
+
+Here are the most common findings from this package, and how to resolve them.
+
+### `runtime.foo signature got "BAR" want "BAZ"; stdlib type changed?`
+
+The definition of `runtime.foo` in the standard library does not match the
+expected type in `known.go`. This means that the function signature in the
+standard library changed.
+
+Addressing this will require creating a new linkname directive in a new Go
+version build-tagged in any packages using this symbol. Be sure to also check to
+ensure use with the new version is safe, as function constraints may have
+changed in addition to the signature.
+
+<!-- TODO(b/165820485): This isn't yet explicitly supported. -->
+
+`known.go` will also need to be updated to accept the new signature for the new
+version of Go.
+
+### `Cannot find known symbol "runtime.foo"`
+
+The standard library has removed runtime.foo entirely. Handling is similar to
+above, except existing code must transition away from the symbol entirely (note
+that is may simply be renamed).
+
+### `linkname to unknown symbol "mypkg.foo"; add this symbol to checklinkname.knownLinknames type-check against the remote type`
+
+A package has added a new linkname directive for a symbol not listed in
+`known.go`. Address this by adding a new entry for the target symbol. The
+`local` field should be the expected type in your package, while `remote` should
+be expected type in the remote package (e.g., in the standard library). These
+are typically identical, in which case `remote` can be omitted.
+
+### `usage: //go:linkname localname [linkname]`
+
+Malformed `//go:linkname` directive. This should be accompanied by a build
+failure in the package.
diff --git a/tools/checklinkname/check_linkname.go b/tools/checklinkname/check_linkname.go
new file mode 100644
index 000000000..5373dd762
--- /dev/null
+++ b/tools/checklinkname/check_linkname.go
@@ -0,0 +1,229 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package checklinkname ensures that linkname declarations match their source.
+package checklinkname
+
+import (
+ "fmt"
+ "go/ast"
+ "go/token"
+ "go/types"
+ "strings"
+
+ "golang.org/x/tools/go/analysis"
+)
+
+// Analyzer implements the checklinkname analyzer.
+var Analyzer = &analysis.Analyzer{
+ Name: "checklinkname",
+ Doc: "verifies that linkname declarations match their source",
+ Run: run,
+}
+
+// go:linkname can be rather confusing. https://pkg.go.dev/cmd/compile says:
+//
+// //go:linkname localname [importpath.name]
+//
+// This special directive does not apply to the Go code that follows it.
+// Instead, the //go:linkname directive instructs the compiler to use
+// “importpath.name” as the object file symbol name for the variable or
+// function declared as “localname” in the source code. If the
+// “importpath.name” argument is omitted, the directive uses the symbol's
+// default object file symbol name and only has the effect of making the symbol
+// accessible to other packages. Because this directive can subvert the type
+// system and package modularity, it is only enabled in files that have
+// imported "unsafe".
+//
+// In this package we use the term "local" to refer to the symbol name in the
+// same package as the //go:linkname directive, whose name will be changed by
+// the linker. We use the term "remote" to refer to the symbol name that we are
+// changing to.
+//
+// In the general case, the local symbol is a function declaration, and the
+// remote symbol is a real function in the standard library.
+
+// linknameSignatures describes a the type signatures of the symbols in a
+// //go:linkname directive.
+type linknameSignatures struct {
+ local string
+ remote string // equivalent to local if "".
+}
+
+func (l *linknameSignatures) Remote() string {
+ if l.remote == "" {
+ return l.local
+ }
+ return l.remote
+}
+
+// linknameSymbols describes the symbol namess in a single //go:linkname
+// directive.
+type linknameSymbols struct {
+ pos token.Pos
+ local string
+ remote string
+}
+
+func findLinknames(pass *analysis.Pass, f *ast.File) []linknameSymbols {
+ var names []linknameSymbols
+
+ for _, cg := range f.Comments {
+ for _, c := range cg.List {
+ if len(c.Text) <= 2 || !strings.HasPrefix(c.Text[2:], "go:linkname ") {
+ continue
+ }
+
+ f := strings.Fields(c.Text)
+ if len(f) < 2 || len(f) > 3 {
+ // Malformed linkname. This is the same error the compiler emits.
+ pass.Reportf(c.Slash, "usage: //go:linkname localname [linkname]")
+ }
+
+ if len(f) == 2 {
+ // "If the “importpath.name” argument is
+ // omitted, the directive uses the symbol's
+ // default object file symbol name and only has
+ // the effect of making the symbol accessible
+ // to other packages."
+ // -https://golang.org/cmd/compile
+ //
+ // There is no type-checking to be done here.
+ continue
+ }
+
+ names = append(names, linknameSymbols{
+ pos: c.Slash,
+ local: f[1],
+ remote: f[2],
+ })
+ }
+ }
+
+ return names
+}
+
+func splitSymbol(pkg *types.Package, symbol string) (packagePath, name string) {
+ // Note that some runtime symbols can have multiple dots. e.g.,
+ // runtime..init_task.
+ s := strings.SplitN(symbol, ".", 2)
+
+ switch len(s) {
+ case 1:
+ // Package name omitted, use current package.
+ return pkg.Path(), symbol
+ case 2:
+ return s[0], s[1]
+ default:
+ panic("unreachable")
+ }
+}
+
+func findObject(pkg *types.Package, symbol string) (types.Object, error) {
+ packagePath, symbolName := splitSymbol(pkg, symbol)
+ return findPackageObject(pkg, packagePath, symbolName)
+}
+
+func findPackageObject(pkg *types.Package, packagePath, symbolName string) (types.Object, error) {
+ if pkg.Path() == packagePath {
+ o := pkg.Scope().Lookup(symbolName)
+ if o == nil {
+ return nil, fmt.Errorf("%q not found in %q (names: %+v)", symbolName, packagePath, pkg.Scope().Names())
+ }
+ return o, nil
+ }
+
+ for _, p := range pkg.Imports() {
+ if o, err := findPackageObject(p, packagePath, symbolName); err == nil {
+ return o, nil
+ }
+ }
+
+ return nil, fmt.Errorf("package %q not found", packagePath)
+}
+
+// checkOneLinkname verifies that the type of sym.local matches the type from
+// knownLinknames.
+func checkOneLinkname(pass *analysis.Pass, f *ast.File, sym linknameSymbols) {
+ remotePackage, remoteName := splitSymbol(pass.Pkg, sym.remote)
+
+ m, ok := knownLinknames[remotePackage]
+ if !ok {
+ pass.Reportf(sym.pos, "linkname to unknown symbol %q; add this symbol to checklinkname.knownLinknames type-check against the remote type", sym.remote)
+ return
+ }
+
+ linkname, ok := m[remoteName]
+ if !ok {
+ pass.Reportf(sym.pos, "linkname to unknown symbol %q; add this symbol to checklinkname.knownLinknames type-check against the remote type", sym.remote)
+ return
+ }
+
+ local, err := findObject(pass.Pkg, sym.local)
+ if err != nil {
+ pass.Reportf(sym.pos, "Unable to find symbol %q: %v", sym.local, err)
+ return
+ }
+
+ localSig, ok := local.Type().(*types.Signature)
+ if !ok {
+ pass.Reportf(local.Pos(), "%q object is not a signature: %+#v", sym.local, local)
+ return
+ }
+
+ if linkname.local != localSig.String() {
+ pass.Reportf(local.Pos(), "%q signature got %q want %q; mismatched types?", sym.local, localSig.String(), linkname.local)
+ return
+ }
+}
+
+// checkOneRemote verifies that the type of sym matches wantSig.
+func checkOneRemote(pass *analysis.Pass, sym, wantSig string) {
+ o := pass.Pkg.Scope().Lookup(sym)
+ if o == nil {
+ pass.Reportf(pass.Files[0].Package, "Cannot find known symbol %q", sym)
+ return
+ }
+
+ sig, ok := o.Type().(*types.Signature)
+ if !ok {
+ pass.Reportf(o.Pos(), "%q object is not a signature: %+#v", sym, o)
+ return
+ }
+
+ if sig.String() != wantSig {
+ pass.Reportf(o.Pos(), "%q signature got %q want %q; stdlib type changed?", sym, sig.String(), wantSig)
+ return
+ }
+}
+
+func run(pass *analysis.Pass) (interface{}, error) {
+ // First, check if any remote symbols are in this package.
+ p, ok := knownLinknames[pass.Pkg.Path()]
+ if ok {
+ for sym, l := range p {
+ checkOneRemote(pass, sym, l.Remote())
+ }
+ }
+
+ // Then check for local //go:linkname directives in this package.
+ for _, f := range pass.Files {
+ names := findLinknames(pass, f)
+ for _, n := range names {
+ checkOneLinkname(pass, f, n)
+ }
+ }
+
+ return nil, nil
+}
diff --git a/tools/checklinkname/known.go b/tools/checklinkname/known.go
new file mode 100644
index 000000000..54e5155fc
--- /dev/null
+++ b/tools/checklinkname/known.go
@@ -0,0 +1,110 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package checklinkname
+
+// knownLinknames is the set of the symbols for which we can do a rudimentary
+// type-check on.
+//
+// When analyzing the remote package (e.g., runtime), we verify the symbol
+// signature matches 'remote'. When analyzing local packages with //go:linkname
+// directives, we verify the symbol signature matches 'local'.
+//
+// Usually these are identical, but may differ slightly if equivalent
+// replacement types are used in the local packages, such as a copy of a struct
+// or uintptr instead of a pointer type.
+//
+// NOTE: It is the responsibility of the developer to verify the safety of the
+// signatures used here! This analyzer only checks that types match this map;
+// it does not verify compatibility of the entries themselves.
+//
+// //go:linkname directives with no corresponding entry here will trigger a
+// finding.
+//
+// We preform only rudimentary string-based type-checking due to limitations in
+// the analysis framework. Ideally, from the local package we'd lookup the
+// remote symbol's types.Object and perform robust type-checking.
+// Unfortunately, remote symbols are typically loaded from the remote package's
+// gcexportdata. Since //go:linkname targets are usually not exported symbols,
+// they are no included in gcexportdata and we cannot load their types.Object.
+//
+// TODO(b/165820485): Add option to specific per-version signatures.
+var knownLinknames = map[string]map[string]linknameSignatures{
+ "runtime": map[string]linknameSignatures{
+ "entersyscall": linknameSignatures{
+ local: "func()",
+ },
+ "entersyscallblock": linknameSignatures{
+ local: "func()",
+ },
+ "exitsyscall": linknameSignatures{
+ local: "func()",
+ },
+ "fastrand": linknameSignatures{
+ local: "func() uint32",
+ },
+ "gopark": linknameSignatures{
+ // TODO(b/165820485): add verification of waitReason
+ // size and reason and traceEv values.
+ local: "func(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int)",
+ remote: "func(unlockf func(*runtime.g, unsafe.Pointer) bool, lock unsafe.Pointer, reason runtime.waitReason, traceEv byte, traceskip int)",
+ },
+ "goready": linknameSignatures{
+ local: "func(gp uintptr, traceskip int)",
+ remote: "func(gp *runtime.g, traceskip int)",
+ },
+ "goyield": linknameSignatures{
+ local: "func()",
+ },
+ "memmove": linknameSignatures{
+ local: "func(to unsafe.Pointer, from unsafe.Pointer, n uintptr)",
+ },
+ "throw": linknameSignatures{
+ local: "func(s string)",
+ },
+ },
+ "sync": map[string]linknameSignatures{
+ "runtime_canSpin": linknameSignatures{
+ local: "func(i int) bool",
+ },
+ "runtime_doSpin": linknameSignatures{
+ local: "func()",
+ },
+ "runtime_Semacquire": linknameSignatures{
+ // The only difference here is the parameter names. We
+ // can't just change our local use to match remote, as
+ // the stdlib runtime and sync packages also disagree
+ // on the name, and the analyzer checks that use as
+ // well.
+ local: "func(addr *uint32)",
+ remote: "func(s *uint32)",
+ },
+ "runtime_Semrelease": linknameSignatures{
+ // See above.
+ local: "func(addr *uint32, handoff bool, skipframes int)",
+ remote: "func(s *uint32, handoff bool, skipframes int)",
+ },
+ },
+ "syscall": map[string]linknameSignatures{
+ "runtime_BeforeFork": linknameSignatures{
+ local: "func()",
+ },
+ "runtime_AfterFork": linknameSignatures{
+ local: "func()",
+ },
+ "runtime_AfterForkInChild": linknameSignatures{
+ local: "func()",
+ },
+ },
+}
diff --git a/tools/checklinkname/test/BUILD b/tools/checklinkname/test/BUILD
new file mode 100644
index 000000000..b29bd84f2
--- /dev/null
+++ b/tools/checklinkname/test/BUILD
@@ -0,0 +1,9 @@
+load("//tools:defs.bzl", "go_library")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "test",
+ testonly = 1,
+ srcs = ["test_unsafe.go"],
+)
diff --git a/tools/checklinkname/test/test_unsafe.go b/tools/checklinkname/test/test_unsafe.go
new file mode 100644
index 000000000..a7504591c
--- /dev/null
+++ b/tools/checklinkname/test/test_unsafe.go
@@ -0,0 +1,34 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package test provides linkname test targets.
+package test
+
+import (
+ _ "unsafe" // for go:linkname.
+)
+
+//go:linkname DetachedLinkname runtime.fastrand
+
+//go:linkname attachedLinkname runtime.entersyscall
+func attachedLinkname()
+
+// AttachedLinkname reexports attachedLinkname because go vet doesn't like an
+// exported go:linkname without a comment starting with "// AttachedLinkname".
+func AttachedLinkname() {
+ attachedLinkname()
+}
+
+// DetachedLinkname has a linkname elsewhere in the file.
+func DetachedLinkname() uint32
diff --git a/tools/constraintutil/BUILD b/tools/constraintutil/BUILD
new file mode 100644
index 000000000..004b708c4
--- /dev/null
+++ b/tools/constraintutil/BUILD
@@ -0,0 +1,18 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "constraintutil",
+ srcs = ["constraintutil.go"],
+ marshal = False,
+ stateify = False,
+ visibility = ["//tools:__subpackages__"],
+)
+
+go_test(
+ name = "constraintutil_test",
+ size = "small",
+ srcs = ["constraintutil_test.go"],
+ library = ":constraintutil",
+)
diff --git a/tools/constraintutil/constraintutil.go b/tools/constraintutil/constraintutil.go
new file mode 100644
index 000000000..fb3fbe5c2
--- /dev/null
+++ b/tools/constraintutil/constraintutil.go
@@ -0,0 +1,169 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package constraintutil provides utilities for working with Go build
+// constraints.
+package constraintutil
+
+import (
+ "bufio"
+ "bytes"
+ "fmt"
+ "go/build/constraint"
+ "io"
+ "os"
+ "strings"
+)
+
+// FromReader extracts the build constraint from the Go source or assembly file
+// whose contents are read by r.
+func FromReader(r io.Reader) (constraint.Expr, error) {
+ // See go/build.parseFileHeader() for the "official" logic that this is
+ // derived from.
+ const (
+ slashStar = "/*"
+ starSlash = "*/"
+ gobuildPrefix = "//go:build"
+ )
+ s := bufio.NewScanner(r)
+ var (
+ inSlashStar = false // between /* and */
+ haveGobuild = false
+ e constraint.Expr
+ )
+Lines:
+ for s.Scan() {
+ line := bytes.TrimSpace(s.Bytes())
+ if !inSlashStar && constraint.IsGoBuild(string(line)) {
+ if haveGobuild {
+ return nil, fmt.Errorf("multiple go:build directives")
+ }
+ haveGobuild = true
+ var err error
+ e, err = constraint.Parse(string(line))
+ if err != nil {
+ return nil, err
+ }
+ }
+ ThisLine:
+ for len(line) > 0 {
+ if inSlashStar {
+ if i := bytes.Index(line, []byte(starSlash)); i >= 0 {
+ inSlashStar = false
+ line = bytes.TrimSpace(line[i+len(starSlash):])
+ continue ThisLine
+ }
+ continue Lines
+ }
+ if bytes.HasPrefix(line, []byte("//")) {
+ continue Lines
+ }
+ // Note that if /* appears in the line, but not at the beginning,
+ // then the line is still non-empty, so skipping this and
+ // terminating below is correct.
+ if bytes.HasPrefix(line, []byte(slashStar)) {
+ inSlashStar = true
+ line = bytes.TrimSpace(line[len(slashStar):])
+ continue ThisLine
+ }
+ // A non-empty non-comment line terminates scanning for go:build.
+ break Lines
+ }
+ }
+ return e, s.Err()
+}
+
+// FromString extracts the build constraint from the Go source or assembly file
+// containing the given data. If no build constraint applies to the file, it
+// returns nil.
+func FromString(str string) (constraint.Expr, error) {
+ return FromReader(strings.NewReader(str))
+}
+
+// FromFile extracts the build constraint from the Go source or assembly file
+// at the given path. If no build constraint applies to the file, it returns
+// nil.
+func FromFile(path string) (constraint.Expr, error) {
+ f, err := os.Open(path)
+ if err != nil {
+ return nil, err
+ }
+ defer f.Close()
+ return FromReader(f)
+}
+
+// Combine returns a constraint.Expr that evaluates to true iff all expressions
+// in es evaluate to true. If es is empty, Combine returns nil.
+//
+// Preconditions: All constraint.Exprs in es are non-nil.
+func Combine(es []constraint.Expr) constraint.Expr {
+ switch len(es) {
+ case 0:
+ return nil
+ case 1:
+ return es[0]
+ default:
+ a := &constraint.AndExpr{es[0], es[1]}
+ for i := 2; i < len(es); i++ {
+ a = &constraint.AndExpr{a, es[i]}
+ }
+ return a
+ }
+}
+
+// CombineFromFiles returns a build constraint expression that evaluates to
+// true iff the build constraints from all of the given Go source or assembly
+// files evaluate to true. If no build constraints apply to any of the given
+// files, it returns nil.
+func CombineFromFiles(paths []string) (constraint.Expr, error) {
+ var es []constraint.Expr
+ for _, path := range paths {
+ e, err := FromFile(path)
+ if err != nil {
+ return nil, fmt.Errorf("failed to read build constraints from %q: %v", path, err)
+ }
+ if e != nil {
+ es = append(es, e)
+ }
+ }
+ return Combine(es), nil
+}
+
+// Lines returns a string containing build constraint directives for the given
+// constraint.Expr, including two trailing newlines, as appropriate for a Go
+// source or assembly file. At least a go:build directive will be emitted; if
+// the constraint is expressible using +build directives as well, then +build
+// directives will also be emitted.
+//
+// If e is nil, Lines returns the empty string.
+func Lines(e constraint.Expr) string {
+ if e == nil {
+ return ""
+ }
+
+ var b strings.Builder
+ b.WriteString("//go:build ")
+ b.WriteString(e.String())
+ b.WriteByte('\n')
+
+ if pblines, err := constraint.PlusBuildLines(e); err == nil {
+ for _, line := range pblines {
+ b.WriteString(line)
+ b.WriteByte('\n')
+ }
+ }
+
+ b.WriteByte('\n')
+ return b.String()
+}
diff --git a/tools/constraintutil/constraintutil_test.go b/tools/constraintutil/constraintutil_test.go
new file mode 100644
index 000000000..eeabd8dcf
--- /dev/null
+++ b/tools/constraintutil/constraintutil_test.go
@@ -0,0 +1,138 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package constraintutil
+
+import (
+ "go/build/constraint"
+ "testing"
+)
+
+func TestFileParsing(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ data string
+ expr string
+ }{
+ {
+ name: "Empty",
+ },
+ {
+ name: "NoConstraint",
+ data: "// copyright header\n\npackage main",
+ },
+ {
+ name: "ConstraintOnFirstLine",
+ data: "//go:build amd64\n#include \"textflag.h\"",
+ expr: "amd64",
+ },
+ {
+ name: "ConstraintAfterSlashSlashComment",
+ data: "// copyright header\n\n//go:build linux\n\npackage newlib",
+ expr: "linux",
+ },
+ {
+ name: "ConstraintAfterSlashStarComment",
+ data: "/*\ncopyright header\n*/\n\n//go:build !race\n\npackage oldlib",
+ expr: "!race",
+ },
+ {
+ name: "ConstraintInSlashSlashComment",
+ data: "// blah blah //go:build windows",
+ },
+ {
+ name: "ConstraintInSlashStarComment",
+ data: "/*\n//go:build windows\n*/",
+ },
+ {
+ name: "ConstraintAfterPackageClause",
+ data: "package oops\n//go:build race",
+ },
+ {
+ name: "ConstraintAfterCppInclude",
+ data: "#include \"textflag.h\"\n//go:build arm64",
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ e, err := FromString(test.data)
+ if err != nil {
+ t.Fatalf("FromString(%q) failed: %v", test.data, err)
+ }
+ if e == nil {
+ if len(test.expr) != 0 {
+ t.Errorf("FromString(%q): got no constraint, wanted %q", test.data, test.expr)
+ }
+ } else {
+ got := e.String()
+ if len(test.expr) == 0 {
+ t.Errorf("FromString(%q): got %q, wanted no constraint", test.data, got)
+ } else if got != test.expr {
+ t.Errorf("FromString(%q): got %q, wanted %q", test.data, got, test.expr)
+ }
+ }
+ })
+ }
+}
+
+func TestCombine(t *testing.T) {
+ for _, test := range []struct {
+ name string
+ in []string
+ out string
+ }{
+ {
+ name: "0",
+ },
+ {
+ name: "1",
+ in: []string{"amd64 || arm64"},
+ out: "amd64 || arm64",
+ },
+ {
+ name: "2",
+ in: []string{"amd64", "amd64 && linux"},
+ out: "amd64 && amd64 && linux",
+ },
+ {
+ name: "3",
+ in: []string{"amd64", "amd64 || arm64", "amd64 || riscv64"},
+ out: "amd64 && (amd64 || arm64) && (amd64 || riscv64)",
+ },
+ } {
+ t.Run(test.name, func(t *testing.T) {
+ inexprs := make([]constraint.Expr, 0, len(test.in))
+ for _, estr := range test.in {
+ line := "//go:build " + estr
+ e, err := constraint.Parse(line)
+ if err != nil {
+ t.Fatalf("constraint.Parse(%q) failed: %v", line, err)
+ }
+ inexprs = append(inexprs, e)
+ }
+ outexpr := Combine(inexprs)
+ if outexpr == nil {
+ if len(test.out) != 0 {
+ t.Errorf("Combine(%v): got no constraint, wanted %q", test.in, test.out)
+ }
+ } else {
+ got := outexpr.String()
+ if len(test.out) == 0 {
+ t.Errorf("Combine(%v): got %q, wanted no constraint", test.in, got)
+ } else if got != test.out {
+ t.Errorf("Combine(%v): got %q, wanted %q", test.in, got, test.out)
+ }
+ }
+ })
+ }
+}
diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD
index 5e0487e93..211e6b3ed 100644
--- a/tools/go_generics/go_merge/BUILD
+++ b/tools/go_generics/go_merge/BUILD
@@ -7,6 +7,6 @@ go_binary(
srcs = ["main.go"],
visibility = ["//:sandbox"],
deps = [
- "//tools/tags",
+ "//tools/constraintutil",
],
)
diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go
index 801f2354f..81394ddce 100644
--- a/tools/go_generics/go_merge/main.go
+++ b/tools/go_generics/go_merge/main.go
@@ -25,9 +25,8 @@ import (
"os"
"path/filepath"
"strconv"
- "strings"
- "gvisor.dev/gvisor/tools/tags"
+ "gvisor.dev/gvisor/tools/constraintutil"
)
var (
@@ -131,6 +130,12 @@ func main() {
}
f.Decls = newDecls
+ // Infer build constraints for the output file.
+ bcexpr, err := constraintutil.CombineFromFiles(flag.Args())
+ if err != nil {
+ fatalf("Failed to read build constraints: %v\n", err)
+ }
+
// Write the output file.
var buf bytes.Buffer
if err := format.Node(&buf, fset, f); err != nil {
@@ -141,9 +146,7 @@ func main() {
fatalf("opening output: %v\n", err)
}
defer outf.Close()
- if t := tags.Aggregate(flag.Args()); len(t) > 0 {
- fmt.Fprintf(outf, "%s\n\n", strings.Join(t.Lines(), "\n"))
- }
+ outf.WriteString(constraintutil.Lines(bcexpr))
if _, err := outf.Write(buf.Bytes()); err != nil {
fatalf("write: %v\n", err)
}
diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD
index c2747d94c..aaa203115 100644
--- a/tools/go_marshal/gomarshal/BUILD
+++ b/tools/go_marshal/gomarshal/BUILD
@@ -18,5 +18,5 @@ go_library(
visibility = [
"//:sandbox",
],
- deps = ["//tools/tags"],
+ deps = ["//tools/constraintutil"],
)
diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go
index 00961c90d..4c23637c0 100644
--- a/tools/go_marshal/gomarshal/generator.go
+++ b/tools/go_marshal/gomarshal/generator.go
@@ -25,7 +25,7 @@ import (
"sort"
"strings"
- "gvisor.dev/gvisor/tools/tags"
+ "gvisor.dev/gvisor/tools/constraintutil"
)
// List of identifiers we use in generated code that may conflict with a
@@ -123,16 +123,18 @@ func (g *Generator) writeHeader() error {
var b sourceBuffer
b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n")
- // Emit build tags.
- b.emit("// If there are issues with build tag aggregation, see\n")
- b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The build tags here\n")
- b.emit("// come from the input set of files used to generate this file. This input set\n")
- b.emit("// is filtered based on pre-defined file suffixes related to build tags, see \n")
- b.emit("// tools/defs.bzl:calculate_sets().\n\n")
-
- if t := tags.Aggregate(g.inputs); len(t) > 0 {
- b.emit(strings.Join(t.Lines(), "\n"))
- b.emit("\n\n")
+ bcexpr, err := constraintutil.CombineFromFiles(g.inputs)
+ if err != nil {
+ return err
+ }
+ if bcexpr != nil {
+ // Emit build constraints.
+ b.emit("// If there are issues with build constraint aggregation, see\n")
+ b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here\n")
+ b.emit("// come from the input set of files used to generate this file. This input set\n")
+ b.emit("// is filtered based on pre-defined file suffixes related to build constraints,\n")
+ b.emit("// see tools/defs.bzl:calculate_sets().\n\n")
+ b.emit(constraintutil.Lines(bcexpr))
}
// Package header.
@@ -553,11 +555,12 @@ func (g *Generator) writeTests(ts []*testGenerator) error {
b.reset()
b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n")
- // Emit build tags.
- if t := tags.Aggregate(g.inputs); len(t) > 0 {
- b.emit(strings.Join(t.Lines(), "\n"))
- b.emit("\n\n")
+ // Emit build constraints.
+ bcexpr, err := constraintutil.CombineFromFiles(g.inputs)
+ if err != nil {
+ return err
}
+ b.emit(constraintutil.Lines(bcexpr))
b.emit("package %s\n\n", g.pkg)
if err := b.write(g.outputTest); err != nil {
diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD
index 913558b4e..ad66981c7 100644
--- a/tools/go_stateify/BUILD
+++ b/tools/go_stateify/BUILD
@@ -6,7 +6,7 @@ go_binary(
name = "stateify",
srcs = ["main.go"],
visibility = ["//:sandbox"],
- deps = ["//tools/tags"],
+ deps = ["//tools/constraintutil"],
)
bzl_library(
diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go
index 93022f504..7216388a0 100644
--- a/tools/go_stateify/main.go
+++ b/tools/go_stateify/main.go
@@ -28,7 +28,7 @@ import (
"strings"
"sync"
- "gvisor.dev/gvisor/tools/tags"
+ "gvisor.dev/gvisor/tools/constraintutil"
)
var (
@@ -214,10 +214,13 @@ func main() {
// Automated warning.
fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n")
- // Emit build tags.
- if t := tags.Aggregate(flag.Args()); len(t) > 0 {
- fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n"))
+ // Emit build constraints.
+ bcexpr, err := constraintutil.CombineFromFiles(flag.Args())
+ if err != nil {
+ fmt.Fprintf(os.Stderr, "Failed to infer build constraints: %v", err)
+ os.Exit(1)
}
+ outputFile.WriteString(constraintutil.Lines(bcexpr))
// Emit the package name.
_, pkg := filepath.Split(*fullPkg)
diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD
index 27fe48680..d72821377 100644
--- a/tools/nogo/BUILD
+++ b/tools/nogo/BUILD
@@ -35,6 +35,7 @@ go_library(
visibility = ["//:sandbox"],
deps = [
"//tools/checkescape",
+ "//tools/checklinkname",
"//tools/checklocks",
"//tools/checkunsafe",
"//tools/nogo/objdump",
diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go
index 2b3c03fec..6705fc905 100644
--- a/tools/nogo/analyzers.go
+++ b/tools/nogo/analyzers.go
@@ -47,6 +47,7 @@ import (
"honnef.co/go/tools/stylecheck"
"gvisor.dev/gvisor/tools/checkescape"
+ "gvisor.dev/gvisor/tools/checklinkname"
"gvisor.dev/gvisor/tools/checklocks"
"gvisor.dev/gvisor/tools/checkunsafe"
)
@@ -80,6 +81,7 @@ var AllAnalyzers = []*analysis.Analyzer{
unusedresult.Analyzer,
checkescape.Analyzer,
checkunsafe.Analyzer,
+ checklinkname.Analyzer,
checklocks.Analyzer,
}
diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go
index d50336b9b..4a925d03c 100644
--- a/tools/nogo/filter/main.go
+++ b/tools/nogo/filter/main.go
@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-// Binary check is the nogo entrypoint.
+// Binary filter is the filters and reports nogo findings.
package main
import (
diff --git a/tools/tags/BUILD b/tools/tags/BUILD
deleted file mode 100644
index 1c02e2c89..000000000
--- a/tools/tags/BUILD
+++ /dev/null
@@ -1,11 +0,0 @@
-load("//tools:defs.bzl", "go_library")
-
-package(licenses = ["notice"])
-
-go_library(
- name = "tags",
- srcs = ["tags.go"],
- marshal = False,
- stateify = False,
- visibility = ["//tools:__subpackages__"],
-)
diff --git a/tools/tags/tags.go b/tools/tags/tags.go
deleted file mode 100644
index f35904e0a..000000000
--- a/tools/tags/tags.go
+++ /dev/null
@@ -1,89 +0,0 @@
-// Copyright 2020 The gVisor Authors.
-//
-// Licensed under the Apache License, Version 2.0 (the "License");
-// you may not use this file except in compliance with the License.
-// You may obtain a copy of the License at
-//
-// http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing, software
-// distributed under the License is distributed on an "AS IS" BASIS,
-// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-// See the License for the specific language governing permissions and
-// limitations under the License.
-
-// Package tags is a utility for parsing build tags.
-package tags
-
-import (
- "fmt"
- "io/ioutil"
- "strings"
-)
-
-// OrSet is a set of tags on a single line.
-//
-// Note that tags may include ",", and we don't distinguish this case in the
-// logic below. Ideally, this constraints can be split into separate top-level
-// build tags in order to resolve any issues.
-type OrSet []string
-
-// Line returns the line for this or.
-func (or OrSet) Line() string {
- return fmt.Sprintf("// +build %s", strings.Join([]string(or), " "))
-}
-
-// AndSet is the set of all OrSets.
-type AndSet []OrSet
-
-// Lines returns the lines to be printed.
-func (and AndSet) Lines() (ls []string) {
- for _, or := range and {
- ls = append(ls, or.Line())
- }
- return
-}
-
-// Join joins this AndSet with another.
-func (and AndSet) Join(other AndSet) AndSet {
- return append(and, other...)
-}
-
-// Tags returns the unique set of +build tags.
-//
-// Derived form the runtime's canBuild.
-func Tags(file string) (tags AndSet) {
- data, err := ioutil.ReadFile(file)
- if err != nil {
- return nil
- }
- // Check file contents for // +build lines.
- for _, p := range strings.Split(string(data), "\n") {
- p = strings.TrimSpace(p)
- if p == "" {
- continue
- }
- if !strings.HasPrefix(p, "//") {
- break
- }
- if !strings.Contains(p, "+build") {
- continue
- }
- fields := strings.Fields(p[2:])
- if len(fields) < 1 || fields[0] != "+build" {
- continue
- }
- tags = append(tags, OrSet(fields[1:]))
- }
- return tags
-}
-
-// Aggregate aggregates all tags from a set of files.
-//
-// Note that these may be in conflict, in which case the build will fail.
-func Aggregate(files []string) (tags AndSet) {
- for _, file := range files {
- tags = tags.Join(Tags(file))
- }
- return tags
-}