Diffstat (limited to 'pkg')
-rw-r--r--  pkg/bitmap/BUILD | 16
-rw-r--r--  pkg/bitmap/bitmap.go | 234
-rw-r--r--  pkg/bitmap/bitmap_test.go | 306
-rw-r--r--  pkg/gohacks/gohacks_unsafe.go | 8
-rw-r--r--  pkg/procid/procid_amd64.s | 4
-rw-r--r--  pkg/procid/procid_arm64.s | 4
-rw-r--r--  pkg/sentry/control/BUILD | 4
-rw-r--r--  pkg/sentry/control/fs.go | 93
-rw-r--r--  pkg/sentry/control/lifecycle.go | 36
-rw-r--r--  pkg/sentry/fsimpl/cgroupfs/cgroupfs.go | 5
-rw-r--r--  pkg/sentry/fsimpl/gofer/gofer.go | 28
-rw-r--r--  pkg/sentry/fsimpl/proc/task.go | 2
-rw-r--r--  pkg/sentry/fsimpl/proc/task_files.go | 111
-rw-r--r--  pkg/sentry/inet/inet.go | 3
-rw-r--r--  pkg/sentry/inet/test_stack.go | 50
-rw-r--r--  pkg/sentry/kernel/BUILD | 1
-rw-r--r--  pkg/sentry/kernel/fd_table.go | 196
-rw-r--r--  pkg/sentry/kernel/fd_table_unsafe.go | 11
-rw-r--r--  pkg/sentry/kernel/msgqueue/msgqueue.go | 278
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_fault.go | 2
-rw-r--r--  pkg/sentry/platform/kvm/bluepill_unsafe.go | 10
-rw-r--r--  pkg/sentry/platform/kvm/machine.go | 13
-rw-r--r--  pkg/sentry/platform/kvm/machine_unsafe.go | 8
-rw-r--r--  pkg/sentry/platform/ptrace/subprocess_unsafe.go | 8
-rw-r--r--  pkg/sentry/socket/hostinet/stack.go | 29
-rw-r--r--  pkg/sentry/socket/netlink/route/protocol.go | 43
-rw-r--r--  pkg/sentry/socket/netstack/stack.go | 6
-rw-r--r--  pkg/sentry/syscalls/linux/BUILD | 1
-rw-r--r--  pkg/sentry/syscalls/linux/linux64.go | 8
-rw-r--r--  pkg/sentry/syscalls/linux/sys_msgqueue.go | 85
-rw-r--r--  pkg/sentry/syscalls/linux/sys_prctl.go | 16
-rw-r--r--  pkg/sync/goyield_unsafe.go | 8
-rw-r--r--  pkg/sync/runtime_unsafe.go | 13
-rw-r--r--  pkg/tcpip/link/fdbased/endpoint.go | 32
-rw-r--r--  pkg/tcpip/link/fdbased/mmap.go | 21
-rw-r--r--  pkg/tcpip/link/fdbased/mmap_unsafe.go | 9
-rw-r--r--  pkg/tcpip/link/fdbased/packet_dispatchers.go | 53
-rw-r--r--  pkg/tcpip/link/qdisc/fifo/endpoint.go | 7
-rw-r--r--  pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go | 7
-rw-r--r--  pkg/tcpip/link/rawfile/rawfile_unsafe.go | 55
-rw-r--r--  pkg/tcpip/link/sniffer/pcap.go | 56
-rw-r--r--  pkg/tcpip/link/sniffer/sniffer.go | 33
-rw-r--r--  pkg/tcpip/network/ipv4/BUILD | 1
-rw-r--r--  pkg/tcpip/network/ipv4/ipv4_test.go | 3
-rw-r--r--  pkg/tcpip/stack/stack.go | 3
-rw-r--r--  pkg/tcpip/stack/stack_test.go | 77
46 files changed, 1682 insertions(+), 315 deletions(-)
diff --git a/pkg/bitmap/BUILD b/pkg/bitmap/BUILD
new file mode 100644
index 000000000..0f1e7006d
--- /dev/null
+++ b/pkg/bitmap/BUILD
@@ -0,0 +1,16 @@
+load("//tools:defs.bzl", "go_library", "go_test")
+
+package(licenses = ["notice"])
+
+go_library(
+ name = "bitmap",
+ srcs = ["bitmap.go"],
+ visibility = ["//:sandbox"],
+)
+
+go_test(
+ name = "bitmap_test",
+ size = "small",
+ srcs = ["bitmap_test.go"],
+ library = ":bitmap",
+)
diff --git a/pkg/bitmap/bitmap.go b/pkg/bitmap/bitmap.go
new file mode 100644
index 000000000..12d2fc2b8
--- /dev/null
+++ b/pkg/bitmap/bitmap.go
@@ -0,0 +1,234 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+// Package bitmap provides the implementation of a bitmap.
+package bitmap
+
+import (
+ "math"
+ "math/bits"
+)
+
+// Bitmap implements an efficient bitmap.
+//
+// +stateify savable
+type Bitmap struct {
+ // numOnes is the number of ones in the bitmap.
+ numOnes uint32
+
+ // bitBlock holds the bits. Each uint64 element stores 64 bits of the
+ // bitmap.
+ bitBlock []uint64
+}
+
+// New creates a new empty Bitmap with capacity for size bits.
+func New(size uint32) Bitmap {
+ b := Bitmap{}
+ bSize := (size + 63) / 64
+ b.bitBlock = make([]uint64, bSize)
+ return b
+}
+
+// IsEmpty verifies whether the Bitmap is empty.
+func (b *Bitmap) IsEmpty() bool {
+ return b.numOnes == 0
+}
+
+// Minimum returns the smallest value in the Bitmap, or math.MaxInt32 if it is empty.
+func (b *Bitmap) Minimum() uint32 {
+ for i := 0; i < len(b.bitBlock); i++ {
+ if w := b.bitBlock[i]; w != 0 {
+ r := bits.TrailingZeros64(w)
+ return uint32(r + i*64)
+ }
+ }
+ return math.MaxInt32
+}
+
+// FirstZero returns the first unset bit at or after start, or math.MaxInt32 if no unset bit is found.
+func (b *Bitmap) FirstZero(start uint32) uint32 {
+ i, nbit := int(start/64), start%64
+ n := len(b.bitBlock)
+ if i >= n {
+ return math.MaxInt32
+ }
+ w := b.bitBlock[i] | ((1 << nbit) - 1)
+ for {
+ if w != ^uint64(0) {
+ r := bits.TrailingZeros64(^w)
+ return uint32(r + i*64)
+ }
+ i++
+ if i == n {
+ break
+ }
+ w = b.bitBlock[i]
+ }
+ return math.MaxInt32
+}
+
+// Maximum returns the largest value in the Bitmap, or 0 if it is empty.
+func (b *Bitmap) Maximum() uint32 {
+ for i := len(b.bitBlock) - 1; i >= 0; i-- {
+ if w := b.bitBlock[i]; w != 0 {
+ r := bits.LeadingZeros64(w)
+ return uint32(i*64 + 63 - r)
+ }
+ }
+ return uint32(0)
+}
+
+// Add adds i to the Bitmap, extending bitBlock if needed.
+func (b *Bitmap) Add(i uint32) {
+ blockNum, mask := i/64, uint64(1)<<(i%64)
+ // If blockNum is out of range, extend b.bitBlock.
+ if x, y := int(blockNum), len(b.bitBlock); x >= y {
+ b.bitBlock = append(b.bitBlock, make([]uint64, x-y+1)...)
+ }
+ oldBlock := b.bitBlock[blockNum]
+ newBlock := oldBlock | mask
+ if oldBlock != newBlock {
+ b.bitBlock[blockNum] = newBlock
+ b.numOnes++
+ }
+}
+
+// Remove removes i from the Bitmap.
+func (b *Bitmap) Remove(i uint32) {
+ blockNum, mask := i/64, uint64(1)<<(i%64)
+ oldBlock := b.bitBlock[blockNum]
+ newBlock := oldBlock &^ mask
+ if oldBlock != newBlock {
+ b.bitBlock[blockNum] = newBlock
+ b.numOnes--
+ }
+}
+
+// Clone returns a copy of the Bitmap.
+func (b *Bitmap) Clone() Bitmap {
+ bitmap := Bitmap{b.numOnes, make([]uint64, len(b.bitBlock))}
+ copy(bitmap.bitBlock, b.bitBlock[:])
+ return bitmap
+}
+
+// countOnesForBlocks counts the set bits in every block from the one
+// containing begin through the one containing end, inclusive.
+func (b *Bitmap) countOnesForBlocks(begin, end uint32) uint64 {
+ ones := uint64(0)
+ beginBlock := begin / 64
+ endBlock := end / 64
+ for i := beginBlock; i <= endBlock; i++ {
+ ones += uint64(bits.OnesCount64(b.bitBlock[i]))
+ }
+ return ones
+}
+
+// countOnesForAllBlocks counts all set bits in b.bitBlock.
+func (b *Bitmap) countOnesForAllBlocks() uint64 {
+ ones := uint64(0)
+ for i := 0; i < len(b.bitBlock); i++ {
+ ones += uint64(bits.OnesCount64(b.bitBlock[i]))
+ }
+ return ones
+}
+
+// flipRange flips the bits in the range [begin, end): begin is inclusive and end is exclusive.
+func (b *Bitmap) flipRange(begin, end uint32) {
+ end--
+ beginBlock := begin / 64
+ endBlock := end / 64
+ if beginBlock == endBlock {
+ b.bitBlock[endBlock] ^= ((^uint64(0) << uint(begin%64)) & ((uint64(1) << (uint(end)%64 + 1)) - 1))
+ } else {
+ b.bitBlock[beginBlock] ^= ^(^uint64(0) << uint(begin%64))
+ for i := beginBlock; i < endBlock; i++ {
+ b.bitBlock[i] = ^b.bitBlock[i]
+ }
+ b.bitBlock[endBlock] ^= ((uint64(1) << (uint(end)%64 + 1)) - 1)
+ }
+}
+
+// clearRange clears the bits in the range [begin, end): begin is inclusive and end is exclusive.
+func (b *Bitmap) clearRange(begin, end uint32) {
+ end--
+ beginBlock := begin / 64
+ endBlock := end / 64
+ if beginBlock == endBlock {
+ b.bitBlock[beginBlock] &= (((uint64(1) << uint(begin%64)) - 1) | ^((uint64(1) << (uint(end)%64 + 1)) - 1))
+ } else {
+ b.bitBlock[beginBlock] &= ((uint64(1) << uint(begin%64)) - 1)
+ for i := beginBlock + 1; i < endBlock; i++ {
+ b.bitBlock[i] = 0
+ }
+ b.bitBlock[endBlock] &= ^((uint64(1) << (uint(end)%64 + 1)) - 1)
+ }
+}
+
+// ClearRange clears the bits in the range [begin, end) and updates numOnes: begin is inclusive and end is exclusive.
+func (b *Bitmap) ClearRange(begin, end uint32) {
+ blockRange := end/64 - begin/64
+ // When the number of cleared blocks is larger than half of the length of b.bitBlock,
+ // counting 1s for the entire bitmap has better performance.
+ if blockRange > uint32(len(b.bitBlock)/2) {
+ b.clearRange(begin, end)
+ b.numOnes = uint32(b.countOnesForAllBlocks())
+ } else {
+ oldRangeOnes := b.countOnesForBlocks(begin, end)
+ b.clearRange(begin, end)
+ newRangeOnes := b.countOnesForBlocks(begin, end)
+ b.numOnes += uint32(newRangeOnes - oldRangeOnes)
+ }
+}
+
+// FlipRange flips the bits in the range [begin, end) and updates numOnes: begin is inclusive and end is exclusive.
+func (b *Bitmap) FlipRange(begin, end uint32) {
+ blockRange := end/64 - begin/64
+ // When the number of flipped blocks is larger than half of the length of b.bitBlock,
+ // counting 1s for the entire bitmap has better performance.
+ if blockRange > uint32(len(b.bitBlock)/2) {
+ b.flipRange(begin, end)
+ b.numOnes = uint32(b.countOnesForAllBlocks())
+ } else {
+ oldRangeOnes := b.countOnesForBlocks(begin, end)
+ b.flipRange(begin, end)
+ newRangeOnes := b.countOnesForBlocks(begin, end)
+ b.numOnes += uint32(newRangeOnes - oldRangeOnes)
+ }
+}
+
+// ToSlice returns a slice containing the indices of all set bits. For example,
+// a bitmap of [0, 1, 0, 1] yields the slice [1, 3].
+func (b *Bitmap) ToSlice() []uint32 {
+ bitmapSlice := make([]uint32, 0, b.numOnes)
+ // base is the bit index at which the current bitBlock starts.
+ base := 0
+ for i := 0; i < len(b.bitBlock); i++ {
+ bitBlock := b.bitBlock[i]
+ // Iterate through all the numbers held by this bit block.
+ for bitBlock != 0 {
+ // Extract the lowest set 1 bit.
+ j := bitBlock & -bitBlock
+ // Interpret the bit as the uint32 number it represents and add it to the
+ // result: OnesCount64(j-1) is the index of the lowest set bit in j.
+ bitmapSlice = append(bitmapSlice, uint32((base + int(bits.OnesCount64(j-1)))))
+ bitBlock ^= j
+ }
+ base += 64
+ }
+ return bitmapSlice
+}
+
+// GetNumOnes returns the number of ones in the Bitmap.
+func (b *Bitmap) GetNumOnes() uint32 {
+ return b.numOnes
+}
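For orientation, here is a short, self-contained usage sketch of the new package (an illustration, not part of the change); the printed values follow from the semantics documented above:

package main

import (
    "fmt"

    "gvisor.dev/gvisor/pkg/bitmap"
)

func main() {
    // Track a set of small integers in [0, 1024); Add extends the
    // underlying storage automatically if a larger value arrives.
    b := bitmap.New(1024)
    b.Add(3)
    b.Add(64) // Lands in the second uint64 block.
    b.Add(65)

    fmt.Println(b.Minimum())    // 3
    fmt.Println(b.Maximum())    // 65
    fmt.Println(b.FirstZero(3)) // 4: bit 3 is set, so the first zero at or after 3 is 4.
    fmt.Println(b.ToSlice())    // [3 64 65]

    b.ClearRange(0, 64) // Clears [0, 64), leaving bits 64 and 65.
    fmt.Println(b.GetNumOnes()) // 2
}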
diff --git a/pkg/bitmap/bitmap_test.go b/pkg/bitmap/bitmap_test.go
new file mode 100644
index 000000000..76ebd779f
--- /dev/null
+++ b/pkg/bitmap/bitmap_test.go
@@ -0,0 +1,306 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package bitmap
+
+import (
+ "math"
+ "reflect"
+ "testing"
+)
+
+// generateFilledSlice generates a slice of evenly spaced numbers in [min, max].
+func generateFilledSlice(min, max, length int) []uint32 {
+ if max == min {
+ return []uint32{uint32(min)}
+ }
+ if length > (max - min) {
+ return nil
+ }
+ randSlice := make([]uint32, length)
+ if length != 0 {
+ rangeNum := uint32((max - min) / length)
+ randSlice[0], randSlice[length-1] = uint32(min), uint32(max)
+ for i := 1; i < length-1; i++ {
+ randSlice[i] = randSlice[i-1] + rangeNum
+ }
+ }
+ return randSlice
+}
+
+// generateFilledBitmap generates a Bitmap filled with fillNum numbers,
+// and returns the slice and bitmap.
+func generateFilledBitmap(min, max, fillNum int) ([]uint32, Bitmap) {
+ bitmap := New(uint32(max))
+ randSlice := generateFilledSlice(min, max, fillNum)
+ for i := 0; i < fillNum; i++ {
+ bitmap.Add(randSlice[i])
+ }
+ return randSlice, bitmap
+}
+
+func TestNewBitmap(t *testing.T) {
+ tests := []struct {
+ name string
+ size int
+ expectSize int
+ }{
+ {"length 1", 1, 1},
+ {"length 1024", 1024, 16},
+ {"length 1025", 1025, 17},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ if bitmap := New(uint32(tt.size)); len(bitmap.bitBlock) != tt.expectSize {
+ t.Errorf("New created bitmap with %v, bitBlock size: %d, wanted: %d", tt.name, len(bitmap.bitBlock), tt.expectSize)
+ }
+ })
+ }
+}
+
+func TestAdd(t *testing.T) {
+ tests := []struct {
+ name string
+ bitmapSize int
+ addNum int
+ }{
+ {"Add with null bitmap.bitBlock", 0, 10},
+ {"Add without extending bitBlock", 64, 10},
+ {"Add without extending bitblock with margin number", 63, 64},
+ {"Add with extended one block", 1024, 1025},
+ {"Add with extended more then one block", 1024, 2048},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ bitmap := New(uint32(tt.bitmapSize))
+ bitmap.Add(uint32(tt.addNum))
+ bitmapSlice := bitmap.ToSlice()
+ if bitmapSlice[0] != uint32(tt.addNum) {
+ t.Errorf("%v, get number: %d, wanted: %d.", tt.name, bitmapSlice[0], tt.addNum)
+ }
+ })
+ }
+}
+
+func TestRemove(t *testing.T) {
+ bitmap := New(uint32(1024))
+ firstSlice := generateFilledSlice(0, 511, 50)
+ secondSlice := generateFilledSlice(512, 1024, 50)
+ for i := 0; i < 50; i++ {
+ bitmap.Add(firstSlice[i])
+ bitmap.Add(secondSlice[i])
+ }
+
+ for i := 0; i < 50; i++ {
+ bitmap.Remove(firstSlice[i])
+ }
+ bitmapSlice := bitmap.ToSlice()
+ if !reflect.DeepEqual(bitmapSlice, secondSlice) {
+ t.Errorf("After Remove() firstSlice, remained slice: %v, wanted: %v", bitmapSlice, secondSlice)
+ }
+
+ for i := 0; i < 50; i++ {
+ bitmap.Remove(secondSlice[i])
+ }
+ bitmapSlice = bitmap.ToSlice()
+ emptySlice := make([]uint32, 0)
+ if !reflect.DeepEqual(bitmapSlice, emptySlice) {
+ t.Errorf("After Remove secondSlice, remained slice: %v, wanted: %v", bitmapSlice, emptySlice)
+ }
+
+}
+
+// Verify flipping a single bit, bits within one bitBlock, and bits that cross multiple bitBlocks.
+func TestFlipRange(t *testing.T) {
+ tests := []struct {
+ name string
+ flipRangeMin int
+ flipRangeMax int
+ filledSliceLen int
+ }{
+ {"Flip one number in bitmap", 77, 77, 1},
+ {"Flip numbers within one bitBlocks", 5, 60, 20},
+ {"Flip numbers that cross multi bitBlocks", 20, 1000, 300},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ fillSlice, bitmap := generateFilledBitmap(tt.flipRangeMin, tt.flipRangeMax, tt.filledSliceLen)
+ flipFillSlice := make([]uint32, 0)
+ for i, j := tt.flipRangeMin, 0; i <= tt.flipRangeMax; i++ {
+ if uint32(i) != fillSlice[j] {
+ flipFillSlice = append(flipFillSlice, uint32(i))
+ } else {
+ j++
+ }
+ }
+
+ bitmap.FlipRange(uint32(tt.flipRangeMin), uint32(tt.flipRangeMax+1))
+ flipBitmapSlice := bitmap.ToSlice()
+ if !reflect.DeepEqual(flipFillSlice, flipBitmapSlice) {
+ t.Errorf("%v, flipped slice: %v, wanted: %v", tt.name, flipBitmapSlice, flipFillSlice)
+ }
+ })
+ }
+}
+
+// Verify clearing a single bit, bits within one bitBlock, and bits that cross multiple bitBlocks.
+func TestClearRange(t *testing.T) {
+ tests := []struct {
+ name string
+ clearRangeMin int
+ clearRangeMax int
+ bitmapSize int
+ }{
+ {"ClearRange clear one number", 5, 5, 64},
+ {"ClearRange clear numbers within one bitBlock", 4, 61, 64},
+ {"ClearRange clear numbers cross multi bitBlocks", 20, 254, 512},
+ }
+
+ for _, tt := range tests {
+ tt := tt
+ t.Run(tt.name, func(t *testing.T) {
+ bitmap := New(uint32(tt.bitmapSize))
+ bitmap.FlipRange(uint32(0), uint32(tt.bitmapSize))
+ bitmap.ClearRange(uint32(tt.clearRangeMin), uint32(tt.clearRangeMax+1))
+ clearedBitmapSlice := bitmap.ToSlice()
+ clearedSlice := make([]uint32, 0)
+ for i := 0; i < tt.bitmapSize; i++ {
+ if i < tt.clearRangeMin || i > tt.clearRangeMax {
+ clearedSlice = append(clearedSlice, uint32(i))
+ }
+ }
+ if !reflect.DeepEqual(clearedSlice, clearedBitmapSlice) {
+ t.Errorf("%v, cleared slice: %v, wanted: %v", tt.name, clearedBitmapSlice, clearedSlice)
+ }
+ })
+
+ }
+}
+
+func TestMinimum(t *testing.T) {
+ randSlice, bitmap := generateFilledBitmap(0, 1024, 200)
+ min := bitmap.Minimum()
+ if min != randSlice[0] {
+ t.Errorf("Minimum() returns: %v, wanted: %v", min, randSlice[0])
+ }
+
+ bitmap.ClearRange(uint32(0), uint32(200))
+ min = bitmap.Minimum()
+ bitmapSlice := bitmap.ToSlice()
+ if min != bitmapSlice[0] {
+ t.Errorf("After ClearRange, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0])
+ }
+
+ bitmap.FlipRange(uint32(2), uint32(11))
+ min = bitmap.Minimum()
+ bitmapSlice = bitmap.ToSlice()
+ if min != bitmapSlice[0] {
+ t.Errorf("After Flip, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0])
+ }
+}
+
+func TestMaximum(t *testing.T) {
+ randSlice, bitmap := generateFilledBitmap(0, 1024, 200)
+ max := bitmap.Maximum()
+ if max != randSlice[len(randSlice)-1] {
+ t.Errorf("Maximum() returns: %v, wanted: %v", max, randSlice[len(randSlice)-1])
+ }
+
+ bitmap.ClearRange(uint32(1000), uint32(1025))
+ max = bitmap.Maximum()
+ bitmapSlice := bitmap.ToSlice()
+ if max != bitmapSlice[len(bitmapSlice)-1] {
+ t.Errorf("After ClearRange, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1])
+ }
+
+ bitmap.FlipRange(uint32(1001), uint32(1021))
+ max = bitmap.Maximum()
+ bitmapSlice = bitmap.ToSlice()
+ if max != bitmapSlice[len(bitmapSlice)-1] {
+ t.Errorf("After Flip, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1])
+ }
+}
+
+func TestBitmapNumOnes(t *testing.T) {
+ randSlice, bitmap := generateFilledBitmap(0, 1024, 200)
+ bitmapOnes := bitmap.GetNumOnes()
+ if bitmapOnes != uint32(200) {
+ t.Errorf("GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200)
+ }
+ // Remove 10 numbers.
+ for i := 5; i < 15; i++ {
+ bitmap.Remove(randSlice[i])
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(190) {
+ t.Errorf("After Remove 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190)
+ }
+ // Remove the same 10 numbers again; the count should not change.
+ for i := 5; i < 15; i++ {
+ bitmap.Remove(randSlice[i])
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(190) {
+ t.Errorf("After Remove the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190)
+ }
+
+ // Add 10 numbers.
+ for i := 1080; i < 1090; i++ {
+ bitmap.Add(uint32(i))
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(200) {
+ t.Errorf("After Add 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200)
+ }
+
+ // Add the same 10 numbers again; the count should not change.
+ for i := 1080; i < 1090; i++ {
+ bitmap.Add(uint32(i))
+ }
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(200) {
+ t.Errorf("After Add the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200)
+ }
+
+ // Flip 10 bits from 0 to 1.
+ bitmap.FlipRange(uint32(1025), uint32(1035))
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(210) {
+ t.Errorf("After Flip, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 210)
+ }
+
+ // Clear bits in the range [0, 1025).
+ bitmap.ClearRange(uint32(0), uint32(1025))
+ bitmapOnes = bitmap.GetNumOnes()
+ if bitmapOnes != uint32(20) {
+ t.Errorf("After ClearRange, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 20)
+ }
+}
+
+func TestFirstZero(t *testing.T) {
+ bitmap := New(uint32(1000))
+ bitmap.FlipRange(200, 400)
+ for i, j := range map[uint32]uint32{0: 0, 201: 400, 200: 400, 199: 199, 400: 400, 10000: math.MaxInt32} {
+ v := bitmap.FirstZero(i)
+ if v != j {
+ t.Errorf("Minimum() returns: %v, wanted: %v", v, j)
+ }
+ }
+}
diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go
index 09fc14787..bd8ceba19 100644
--- a/pkg/gohacks/gohacks_unsafe.go
+++ b/pkg/gohacks/gohacks_unsafe.go
@@ -15,7 +15,13 @@
//go:build go1.13 && !go1.18
// +build go1.13,!go1.18
-// Check type signatures when updating Go version.
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
+
+// Check type signatures and Noescape when updating Go version.
+//
+// TODO(b/165820485): add these checks to checklinkname.
// Package gohacks contains utilities for subverting the Go compiler.
package gohacks
diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s
index b5bbfff90..74a8de42c 100644
--- a/pkg/procid/procid_amd64.s
+++ b/pkg/procid/procid_amd64.s
@@ -15,6 +15,10 @@
//go:build amd64 && go1.8 && !go1.18 && go1.1
// +build amd64,go1.8,!go1.18,go1.1
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
+
#include "textflag.h"
TEXT ·Current(SB),NOSPLIT,$0-8
diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s
index 772d96289..48182c4a9 100644
--- a/pkg/procid/procid_arm64.s
+++ b/pkg/procid/procid_arm64.s
@@ -15,6 +15,10 @@
//go:build arm64 && go1.8 && !go1.18 && go1.1
// +build arm64,go1.8,!go1.18,go1.1
+// //go:linkname directives type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
+
#include "textflag.h"
TEXT ·Current(SB),NOSPLIT,$0-8
diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD
index deaf5fa23..7ee237c9f 100644
--- a/pkg/sentry/control/BUILD
+++ b/pkg/sentry/control/BUILD
@@ -6,6 +6,8 @@ go_library(
name = "control",
srcs = [
"control.go",
+ "fs.go",
+ "lifecycle.go",
"logging.go",
"pprof.go",
"proc.go",
@@ -16,6 +18,7 @@ go_library(
],
deps = [
"//pkg/abi/linux",
+ "//pkg/context",
"//pkg/fd",
"//pkg/log",
"//pkg/sentry/fdimport",
@@ -35,6 +38,7 @@ go_library(
"//pkg/sync",
"//pkg/tcpip/link/sniffer",
"//pkg/urpc",
+ "//pkg/usermem",
],
)
diff --git a/pkg/sentry/control/fs.go b/pkg/sentry/control/fs.go
new file mode 100644
index 000000000..d19b21f2d
--- /dev/null
+++ b/pkg/sentry/control/fs.go
@@ -0,0 +1,93 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "fmt"
+ "io"
+ "os"
+
+ "gvisor.dev/gvisor/pkg/context"
+ "gvisor.dev/gvisor/pkg/sentry/fs"
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+ "gvisor.dev/gvisor/pkg/urpc"
+ "gvisor.dev/gvisor/pkg/usermem"
+)
+
+// CatOpts contains options for the Cat RPC call.
+type CatOpts struct {
+ // Files are the filesystem paths for the files to cat.
+ Files []string `json:"files"`
+
+ // FilePayload contains the destination for output.
+ urpc.FilePayload
+}
+
+// Fs includes fs-related functions.
+type Fs struct {
+ Kernel *kernel.Kernel
+}
+
+// Cat is an RPC stub that prints the contents of the given files to the output file.
+func (f *Fs) Cat(o *CatOpts, _ *struct{}) error {
+ // Expect exactly one file to use as the output stream.
+ if len(o.FilePayload.Files) != 1 {
+ return ErrInvalidFiles
+ }
+
+ output := o.FilePayload.Files[0]
+ for _, file := range o.Files {
+ if err := cat(f.Kernel, file, output); err != nil {
+ return fmt.Errorf("cannot read from file %s: %v", file, err)
+ }
+ }
+
+ return nil
+}
+
+// fileReader encapsulates a fs.File and provides an io.Reader interface.
+type fileReader struct {
+ ctx context.Context
+ file *fs.File
+}
+
+// Read implements io.Reader.Read.
+func (f *fileReader) Read(p []byte) (int, error) {
+ n, err := f.file.Readv(f.ctx, usermem.BytesIOSequence(p))
+ return int(n), err
+}
+
+func cat(k *kernel.Kernel, path string, output *os.File) error {
+ ctx := k.SupervisorContext()
+ mns := k.GlobalInit().Leader().MountNamespace()
+ root := mns.Root()
+ defer root.DecRef(ctx)
+
+ remainingTraversals := uint(fs.DefaultTraversalLimit)
+ d, err := mns.FindInode(ctx, root, nil, path, &remainingTraversals)
+ if err != nil {
+ return fmt.Errorf("cannot find file %s: %v", path, err)
+ }
+ defer d.DecRef(ctx)
+
+ file, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true})
+ if err != nil {
+ return fmt.Errorf("cannot get file for path %s: %v", path, err)
+ }
+ defer file.DecRef(ctx)
+
+ _, err = io.Copy(output, &fileReader{ctx: ctx, file: file})
+ return err
+}
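fileReader above is the standard Go adapter trick: wrap a read-into-buffer API so it satisfies io.Reader, and the standard library (io.Copy here) then works against it unchanged. A self-contained sketch of the same shape, where chunkSource and sourceReader are hypothetical stand-ins for fs.File and fileReader:

package main

import (
    "fmt"
    "io"
    "os"
    "strings"
)

// chunkSource mimics an API like fs.File.Readv that reads into a buffer but
// does not itself implement io.Reader.
type chunkSource struct {
    data *strings.Reader
}

func (c *chunkSource) readInto(p []byte) (int64, error) {
    n, err := c.data.Read(p)
    return int64(n), err
}

// sourceReader adapts chunkSource to io.Reader, mirroring fileReader above.
type sourceReader struct {
    src *chunkSource
}

func (r *sourceReader) Read(p []byte) (int, error) {
    n, err := r.src.readInto(p)
    return int(n), err
}

func main() {
    src := &chunkSource{data: strings.NewReader("hello from the adapter\n")}
    if _, err := io.Copy(os.Stdout, &sourceReader{src: src}); err != nil {
        fmt.Fprintln(os.Stderr, err)
    }
}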
diff --git a/pkg/sentry/control/lifecycle.go b/pkg/sentry/control/lifecycle.go
new file mode 100644
index 000000000..67abf497d
--- /dev/null
+++ b/pkg/sentry/control/lifecycle.go
@@ -0,0 +1,36 @@
+// Copyright 2021 The gVisor Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+package control
+
+import (
+ "gvisor.dev/gvisor/pkg/sentry/kernel"
+)
+
+// Lifecycle provides functions related to starting and stopping tasks.
+type Lifecycle struct {
+ Kernel *kernel.Kernel
+}
+
+// Pause pauses all tasks, blocking until they are stopped.
+func (l *Lifecycle) Pause(_, _ *struct{}) error {
+ l.Kernel.Pause()
+ return nil
+}
+
+// Resume resumes all tasks.
+func (l *Lifecycle) Resume(_, _ *struct{}) error {
+ l.Kernel.Unpause()
+ return nil
+}
diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
index 24e28a51f..22c8b7fda 100644
--- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
+++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go
@@ -383,11 +383,6 @@ func (d *dir) DecRef(ctx context.Context) {
d.dirRefs.DecRef(func() { d.Destroy(ctx) })
}
-// StatFS implements kernfs.Inode.StatFS.
-func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) {
- return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil
-}
-
// controllerFile represents a generic control file that appears within a cgroup
// directory.
//
diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go
index ec8d58cc9..25d2e39d6 100644
--- a/pkg/sentry/fsimpl/gofer/gofer.go
+++ b/pkg/sentry/fsimpl/gofer/gofer.go
@@ -1161,6 +1161,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
if !d.isSynthetic() {
if stat.Mask != 0 {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ // d.dataMu must be held around the update to both the remote
+ // file's size and d.size to serialize with writeback (which
+ // might otherwise write data back up to the old d.size after
+ // the remote file has been truncated).
+ d.dataMu.Lock()
+ }
if err := d.file.setAttr(ctx, p9.SetAttrMask{
Permissions: stat.Mask&linux.STATX_MODE != 0,
UID: stat.Mask&linux.STATX_UID != 0,
@@ -1180,13 +1187,16 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs
MTimeSeconds: uint64(stat.Mtime.Sec),
MTimeNanoSeconds: uint64(stat.Mtime.Nsec),
}); err != nil {
+ if stat.Mask&linux.STATX_SIZE != 0 {
+ d.dataMu.Unlock() // +checklocksforce: locked conditionally above
+ }
return err
}
if stat.Mask&linux.STATX_SIZE != 0 {
// d.size should be kept up to date, and privatized
// copy-on-write mappings of truncated pages need to be
// invalidated, even if InteropModeShared is in effect.
- d.updateSizeLocked(stat.Size)
+ d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above
}
}
if d.fs.opts.interop == InteropModeShared {
@@ -1249,6 +1259,14 @@ func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate
// Preconditions: d.metadataMu must be locked.
func (d *dentry) updateSizeLocked(newSize uint64) {
d.dataMu.Lock()
+ d.updateSizeAndUnlockDataMuLocked(newSize)
+}
+
+// Preconditions: d.metadataMu and d.dataMu must be locked.
+//
+// Postconditions: d.dataMu is unlocked.
+// +checklocksrelease:d.dataMu
+func (d *dentry) updateSizeAndUnlockDataMuLocked(newSize uint64) {
oldSize := d.size
atomic.StoreUint64(&d.size, newSize)
// d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings
@@ -1257,9 +1275,9 @@ func (d *dentry) updateSizeLocked(newSize uint64) {
// contents beyond the new d.size. (We are still holding d.metadataMu,
// so we can't race with Write or another truncate.)
d.dataMu.Unlock()
- if d.size < oldSize {
+ if newSize < oldSize {
oldpgend, _ := hostarch.PageRoundUp(oldSize)
- newpgend, _ := hostarch.PageRoundUp(d.size)
+ newpgend, _ := hostarch.PageRoundUp(newSize)
if oldpgend != newpgend {
d.mapsMu.Lock()
d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{
@@ -1275,8 +1293,8 @@ func (d *dentry) updateSizeLocked(newSize uint64) {
// truncated pages have been removed from the remote file, they
// should be dropped without being written back.
d.dataMu.Lock()
- d.cache.Truncate(d.size, d.fs.mfp.MemoryFile())
- d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend})
+ d.cache.Truncate(newSize, d.fs.mfp.MemoryFile())
+ d.dirty.KeepClean(memmap.MappableRange{newSize, oldpgend})
d.dataMu.Unlock()
}
}
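The truncate path now hands a held lock across a call boundary: setStat may take dataMu itself before the remote setattr, so updateSizeLocked is split so that the shared helper ends by releasing dataMu. A minimal sketch of that lock-handoff shape, with hypothetical names (object and mu stand in for dentry and dataMu):

package main

import (
    "fmt"
    "sync"
)

type object struct {
    mu   sync.Mutex
    size uint64
}

// shrink mirrors updateSizeLocked: it takes mu itself and delegates to the
// unlocking helper.
func (o *object) shrink(newSize uint64) {
    o.mu.Lock()
    o.shrinkAndUnlock(newSize)
}

// shrinkAndUnlock mirrors updateSizeAndUnlockDataMuLocked: the caller may
// have taken mu already (e.g. around a remote operation), so the helper
// performs the protected update and then releases mu, letting follow-up
// work acquire other locks without holding it.
//
// Preconditions: o.mu must be locked.
// Postconditions: o.mu is unlocked.
func (o *object) shrinkAndUnlock(newSize uint64) {
    o.size = newSize
    o.mu.Unlock()
    // Work that must run outside mu (invalidating mappings in the real
    // code) would go here.
}

func main() {
    // Path 1: caller did not pre-lock; use the locking wrapper.
    o := &object{size: 100}
    o.shrink(75)

    // Path 2: caller pre-locked mu around some other update, then hands
    // the lock off to the helper, as setStat does for STATX_SIZE.
    o.mu.Lock()
    o.shrinkAndUnlock(50)

    fmt.Println(o.size) // 50
}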
diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go
index cbbc0935a..f54811edf 100644
--- a/pkg/sentry/fsimpl/proc/task.go
+++ b/pkg/sentry/fsimpl/proc/task.go
@@ -78,7 +78,7 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns
"smaps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &smapsData{task: task}),
"stat": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}),
"statm": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statmData{task: task}),
- "status": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}),
+ "status": fs.newStatusInode(ctx, task, pidns, fs.NextIno(), 0444),
"uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}),
}
if isThreadGroup {
diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go
index 5bb6bc372..0ce3ed797 100644
--- a/pkg/sentry/fsimpl/proc/task_files.go
+++ b/pkg/sentry/fsimpl/proc/task_files.go
@@ -661,34 +661,119 @@ func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error {
return nil
}
-// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status.
+// statusInode implements kernfs.Inode for /proc/[pid]/status.
//
// +stateify savable
-type statusData struct {
- kernfs.DynamicBytesFile
+type statusInode struct {
+ kernfs.InodeAttrs
+ kernfs.InodeNoStatFS
+ kernfs.InodeNoopRefCount
+ kernfs.InodeNotDirectory
+ kernfs.InodeNotSymlink
task *kernel.Task
pidns *kernel.PIDNamespace
+ locks vfs.FileLocks
}
-var _ dynamicInode = (*statusData)(nil)
+// statusFD implements vfs.FileDescriptionImpl and vfs.DynamicBytesSource for
+// /proc/[pid]/status.
+//
+// +stateify savable
+type statusFD struct {
+ statusFDLowerBase
+ vfs.DynamicBytesFileDescriptionImpl
+ vfs.LockFD
+
+ vfsfd vfs.FileDescription
+
+ inode *statusInode
+ task *kernel.Task
+ pidns *kernel.PIDNamespace
+ userns *auth.UserNamespace // equivalent to struct file::f_cred::user_ns
+}
+
+// statusFDLowerBase is a dumb hack to ensure that statusFD prefers
+// vfs.DynamicBytesFileDescriptionImpl methods to vfs.FileDescriptionDefaultImpl
+// methods.
+//
+// +stateify savable
+type statusFDLowerBase struct {
+ vfs.FileDescriptionDefaultImpl
+}
+
+func (fs *filesystem) newStatusInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, ino uint64, perm linux.FileMode) kernfs.Inode {
+ // Note: credentials are overridden by taskOwnedInode.
+ inode := &statusInode{
+ task: task,
+ pidns: pidns,
+ }
+ inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeRegular|perm)
+ return &taskOwnedInode{Inode: inode, owner: task}
+}
+
+// Open implements kernfs.Inode.Open.
+func (s *statusInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) {
+ fd := &statusFD{
+ inode: s,
+ task: s.task,
+ pidns: s.pidns,
+ userns: rp.Credentials().UserNamespace,
+ }
+ fd.LockFD.Init(&s.locks)
+ if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil {
+ return nil, err
+ }
+ fd.SetDataSource(fd)
+ return &fd.vfsfd, nil
+}
+
+// SetStat implements kernfs.Inode.SetStat.
+func (*statusInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error {
+ return linuxerr.EPERM
+}
+
+// Release implements vfs.FileDescriptionImpl.Release.
+func (s *statusFD) Release(ctx context.Context) {
+}
+
+// Stat implements vfs.FileDescriptionImpl.Stat.
+func (s *statusFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) {
+ fs := s.vfsfd.VirtualDentry().Mount().Filesystem()
+ return s.inode.Stat(ctx, fs, opts)
+}
+
+// SetStat implements vfs.FileDescriptionImpl.SetStat.
+func (s *statusFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error {
+ return linuxerr.EPERM
+}
// Generate implements vfs.DynamicBytesSource.Generate.
-func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
+func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error {
fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name())
fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus())
fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup()))
fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task))
+
ppid := kernel.ThreadID(0)
if parent := s.task.Parent(); parent != nil {
ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup())
}
fmt.Fprintf(buf, "PPid:\t%d\n", ppid)
+
tpid := kernel.ThreadID(0)
if tracer := s.task.Tracer(); tracer != nil {
tpid = s.pidns.IDOfTask(tracer)
}
fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid)
+
+ creds := s.task.Credentials()
+ ruid := creds.RealKUID.In(s.userns).OrOverflow()
+ euid := creds.EffectiveKUID.In(s.userns).OrOverflow()
+ suid := creds.SavedKUID.In(s.userns).OrOverflow()
+ rgid := creds.RealKGID.In(s.userns).OrOverflow()
+ egid := creds.EffectiveKGID.In(s.userns).OrOverflow()
+ sgid := creds.SavedKGID.In(s.userns).OrOverflow()
var fds int
var vss, rss, data uint64
s.task.WithMuLocked(func(t *kernel.Task) {
@@ -701,12 +786,26 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error {
data = mm.VirtualDataSize()
}
})
+ // Filesystem user/group IDs aren't implemented; effective UID/GID are used
+ // instead.
+ fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid)
+ fmt.Fprintf(buf, "Gid:\t%d\t%d\t%d\t%d\n", rgid, egid, sgid, egid)
fmt.Fprintf(buf, "FDSize:\t%d\n", fds)
+ buf.WriteString("Groups:\t ")
+ // There is a space between each pair of supplemental GIDs, as well as an
+ // unconditional trailing space that some applications actually depend on.
+ var sep string
+ for _, kgid := range creds.ExtraKGIDs {
+ fmt.Fprintf(buf, "%s%d", sep, kgid.In(s.userns).OrOverflow())
+ sep = " "
+ }
+ buf.WriteString(" \n")
+
fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10)
fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10)
fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10)
+
fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count())
- creds := s.task.Credentials()
fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps)
fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps)
fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps)
diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go
index 80dda1559..b121fc1b4 100644
--- a/pkg/sentry/inet/inet.go
+++ b/pkg/sentry/inet/inet.go
@@ -27,6 +27,9 @@ type Stack interface {
// integers.
Interfaces() map[int32]Interface
+ // RemoveInterface removes the specified network interface.
+ RemoveInterface(idx int32) error
+
// InterfaceAddrs returns all network interface addresses as a mapping from
// interface indexes to a slice of associated interface address properties.
InterfaceAddrs() map[int32][]InterfaceAddr
diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go
index 218d9dafc..621f47e1f 100644
--- a/pkg/sentry/inet/test_stack.go
+++ b/pkg/sentry/inet/test_stack.go
@@ -45,23 +45,29 @@ func NewTestStack() *TestStack {
}
}
-// Interfaces implements Stack.Interfaces.
+// Interfaces implements Stack.
func (s *TestStack) Interfaces() map[int32]Interface {
return s.InterfacesMap
}
-// InterfaceAddrs implements Stack.InterfaceAddrs.
+// RemoveInterface implements Stack.
+func (s *TestStack) RemoveInterface(idx int32) error {
+ delete(s.InterfacesMap, idx)
+ return nil
+}
+
+// InterfaceAddrs implements Stack.
func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr {
return s.InterfaceAddrsMap
}
-// AddInterfaceAddr implements Stack.AddInterfaceAddr.
+// AddInterfaceAddr implements Stack.
func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error {
s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr)
return nil
}
-// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr.
+// RemoveInterfaceAddr implements Stack.
func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error {
interfaceAddrs, ok := s.InterfaceAddrsMap[idx]
if !ok {
@@ -79,94 +85,94 @@ func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error {
return nil
}
-// SupportsIPv6 implements Stack.SupportsIPv6.
+// SupportsIPv6 implements Stack.
func (s *TestStack) SupportsIPv6() bool {
return s.SupportsIPv6Flag
}
-// TCPReceiveBufferSize implements Stack.TCPReceiveBufferSize.
+// TCPReceiveBufferSize implements Stack.
func (s *TestStack) TCPReceiveBufferSize() (TCPBufferSize, error) {
return s.TCPRecvBufSize, nil
}
-// SetTCPReceiveBufferSize implements Stack.SetTCPReceiveBufferSize.
+// SetTCPReceiveBufferSize implements Stack.
func (s *TestStack) SetTCPReceiveBufferSize(size TCPBufferSize) error {
s.TCPRecvBufSize = size
return nil
}
-// TCPSendBufferSize implements Stack.TCPSendBufferSize.
+// TCPSendBufferSize implements Stack.
func (s *TestStack) TCPSendBufferSize() (TCPBufferSize, error) {
return s.TCPSendBufSize, nil
}
-// SetTCPSendBufferSize implements Stack.SetTCPSendBufferSize.
+// SetTCPSendBufferSize implements Stack.
func (s *TestStack) SetTCPSendBufferSize(size TCPBufferSize) error {
s.TCPSendBufSize = size
return nil
}
-// TCPSACKEnabled implements Stack.TCPSACKEnabled.
+// TCPSACKEnabled implements Stack.
func (s *TestStack) TCPSACKEnabled() (bool, error) {
return s.TCPSACKFlag, nil
}
-// SetTCPSACKEnabled implements Stack.SetTCPSACKEnabled.
+// SetTCPSACKEnabled implements Stack.
func (s *TestStack) SetTCPSACKEnabled(enabled bool) error {
s.TCPSACKFlag = enabled
return nil
}
-// TCPRecovery implements Stack.TCPRecovery.
+// TCPRecovery implements Stack.
func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) {
return s.Recovery, nil
}
-// SetTCPRecovery implements Stack.SetTCPRecovery.
+// SetTCPRecovery implements Stack.
func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error {
s.Recovery = recovery
return nil
}
-// Statistics implements inet.Stack.Statistics.
+// Statistics implements Stack.
func (s *TestStack) Statistics(stat interface{}, arg string) error {
return nil
}
-// RouteTable implements Stack.RouteTable.
+// RouteTable implements Stack.
func (s *TestStack) RouteTable() []Route {
return s.RouteList
}
-// Resume implements Stack.Resume.
+// Resume implements Stack.
func (s *TestStack) Resume() {}
-// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
+// RegisteredEndpoints implements Stack.
func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint {
return nil
}
-// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
+// CleanupEndpoints implements Stack.
func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint {
return nil
}
-// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
+// RestoreCleanupEndpoints implements Stack.
func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
-// SetForwarding implements inet.Stack.SetForwarding.
+// SetForwarding implements Stack.
func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error {
s.IPForwarding = enable
return nil
}
-// PortRange implements inet.Stack.PortRange.
+// PortRange implements Stack.
func (*TestStack) PortRange() (uint16, uint16) {
// Use the default Linux values per net/ipv4/af_inet.c:inet_init_net().
return 32768, 28232
}
-// SetPortRange implements inet.Stack.SetPortRange.
+// SetPortRange implements Stack.
func (*TestStack) SetPortRange(start uint16, end uint16) error {
// No-op.
return nil
diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD
index c613f4932..e4e0dc04f 100644
--- a/pkg/sentry/kernel/BUILD
+++ b/pkg/sentry/kernel/BUILD
@@ -220,6 +220,7 @@ go_library(
"//pkg/abi/linux",
"//pkg/abi/linux/errno",
"//pkg/amutex",
+ "//pkg/bitmap",
"//pkg/bits",
"//pkg/bpf",
"//pkg/cleanup",
diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go
index 8786a70b5..eff556a0c 100644
--- a/pkg/sentry/kernel/fd_table.go
+++ b/pkg/sentry/kernel/fd_table.go
@@ -18,10 +18,10 @@ import (
"fmt"
"math"
"strings"
- "sync/atomic"
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/abi/linux"
+ "gvisor.dev/gvisor/pkg/bitmap"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
"gvisor.dev/gvisor/pkg/sentry/fs"
@@ -84,13 +84,8 @@ type FDTable struct {
// mu protects below.
mu sync.Mutex `state:"nosave"`
- // next is start position to find fd.
- next int32
-
- // used contains the number of non-nil entries. It must be accessed
- // atomically. It may be read atomically without holding mu (but not
- // written).
- used int32
+ // fdBitmap tracks which fds are currently in use.
+ fdBitmap bitmap.Bitmap `state:"nosave"`
// descriptorTable holds descriptors.
descriptorTable `state:".(map[int32]descriptor)"`
@@ -98,6 +93,8 @@ type FDTable struct {
func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
m := make(map[int32]descriptor)
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
m[fd] = descriptor{
file: file,
@@ -111,12 +108,16 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor {
func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) {
ctx := context.Background()
f.initNoLeakCheck() // Initialize table.
- f.used = 0
+ f.fdBitmap = bitmap.New(uint32(math.MaxUint16))
for fd, d := range m {
+ if fd < 0 {
+ panic(fmt.Sprintf("FD is not supposed to be negative. FD: %d", fd))
+ }
+
if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil {
panic("VFS1 or VFS2 files set")
}
-
+ f.fdBitmap.Add(uint32(fd))
+ // Note that we do _not_ need to acquire an extra table reference here. The
// table reference will already be accounted for in the file, so we drop the
// reference taken by set above.
@@ -189,8 +190,10 @@ func (f *FDTable) DecRef(ctx context.Context) {
func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) {
// retries tracks the number of failed TryIncRef attempts for the same FD.
retries := 0
- fd := int32(0)
- for {
+ fds := f.fdBitmap.ToSlice()
+ // Iterate through the fdBitmap.
+ for _, ufd := range fds {
+ fd := int32(ufd)
file, fileVFS2, flags, ok := f.getAll(fd)
if !ok {
break
@@ -218,7 +221,6 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File,
fileVFS2.DecRef(ctx)
}
retries = 0
- fd++
}
}
@@ -226,6 +228,8 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File,
func (f *FDTable) String() string {
var buf strings.Builder
ctx := context.Background()
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
switch {
case file != nil:
@@ -250,10 +254,10 @@ func (f *FDTable) String() string {
}
// NewFDs allocates new FDs guaranteed to be the lowest number available
-// greater than or equal to the fd parameter. All files will share the set
+// greater than or equal to the minFD parameter. All files will share the set
// flags. Success is guaranteed to be all or none.
-func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags FDFlags) (fds []int32, err error) {
- if fd < 0 {
+func (f *FDTable) NewFDs(ctx context.Context, minFD int32, files []*fs.File, flags FDFlags) (fds []int32, err error) {
+ if minFD < 0 {
// Don't accept negative FDs.
return nil, unix.EINVAL
}
@@ -267,31 +271,48 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
if lim.Cur != limits.Infinity {
end = int32(lim.Cur)
}
- if fd >= end {
+ if minFD+int32(len(files)) > end {
return nil, unix.EMFILE
}
}
f.mu.Lock()
- // From f.next to find available fd.
- if fd < f.next {
- fd = f.next
+ // max is one greater than the largest fd currently in fdBitmap.
+ max := int32(0)
+
+ if !f.fdBitmap.IsEmpty() {
+ max = int32(f.fdBitmap.Maximum())
+ max++
}
+ // Adjust max in case it is less than minFD.
+ if max < minFD {
+ max = minFD
+ }
// Install all entries.
- for i := fd; i < end && len(fds) < len(files); i++ {
- if d, _, _ := f.get(i); d == nil {
- // Set the descriptor.
- f.set(ctx, i, files[len(fds)], flags)
- fds = append(fds, i) // Record the file descriptor.
+ for len(fds) < len(files) {
+ // Try to use a free bit in fdBitmap.
+ // If all bits in fdBitmap are in use, allocate the next fd at max.
+ fd := f.fdBitmap.FirstZero(uint32(minFD))
+ if fd == math.MaxInt32 {
+ fd = uint32(max)
+ max++
+ }
+ if fd >= uint32(end) {
+ break
}
+ f.fdBitmap.Add(fd)
+ f.set(ctx, int32(fd), files[len(fds)], flags)
+ fds = append(fds, int32(fd))
+ minFD = int32(fd)
}
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
f.set(ctx, i, nil, FDFlags{})
+ f.fdBitmap.Remove(uint32(i))
}
f.mu.Unlock()
@@ -305,20 +326,15 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags
return nil, unix.EMFILE
}
- if fd == f.next {
- // Update next search start position.
- f.next = fds[len(fds)-1] + 1
- }
-
f.mu.Unlock()
return fds, nil
}
// NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available
-// greater than or equal to the fd parameter. All files will share the set
+// greater than or equal to the minFD parameter. All files will share the set
// flags. Success is guaranteed to be all or none.
-func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
- if fd < 0 {
+func (f *FDTable) NewFDsVFS2(ctx context.Context, minFD int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) {
+ if minFD < 0 {
// Don't accept negative FDs.
return nil, unix.EINVAL
}
@@ -332,31 +348,47 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
if lim.Cur != limits.Infinity {
end = int32(lim.Cur)
}
- if fd >= end {
+ if minFD >= end {
return nil, unix.EMFILE
}
}
f.mu.Lock()
- // From f.next to find available fd.
- if fd < f.next {
- fd = f.next
+ // max is one greater than the largest fd currently in fdBitmap.
+ max := int32(0)
+
+ if !f.fdBitmap.IsEmpty() {
+ max = int32(f.fdBitmap.Maximum())
+ max++
}
- // Install all entries.
- for i := fd; i < end && len(fds) < len(files); i++ {
- if d, _, _ := f.getVFS2(i); d == nil {
- // Set the descriptor.
- f.setVFS2(ctx, i, files[len(fds)], flags)
- fds = append(fds, i) // Record the file descriptor.
- }
+ // Adjust max in case it is less than minFD.
+ if max < minFD {
+ max = minFD
}
+ for len(fds) < len(files) {
+ // Try to use a free bit in fdBitmap.
+ // If all bits in fdBitmap are in use, allocate the next fd at max.
+ fd := f.fdBitmap.FirstZero(uint32(minFD))
+ if fd == math.MaxInt32 {
+ fd = uint32(max)
+ max++
+ }
+ if fd >= uint32(end) {
+ break
+ }
+ f.fdBitmap.Add(fd)
+ f.setVFS2(ctx, int32(fd), files[len(fds)], flags)
+ fds = append(fds, int32(fd))
+ minFD = int32(fd)
+ }
// Failure? Unwind existing FDs.
if len(fds) < len(files) {
for _, i := range fds {
f.setVFS2(ctx, i, nil, FDFlags{})
+ f.fdBitmap.Remove(uint32(i))
}
f.mu.Unlock()
@@ -370,57 +402,19 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes
return nil, unix.EMFILE
}
- if fd == f.next {
- // Update next search start position.
- f.next = fds[len(fds)-1] + 1
- }
-
f.mu.Unlock()
return fds, nil
}
-// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for
+// NewFDVFS2 allocates a file descriptor greater than or equal to minFD for
// the given file description. If it succeeds, it takes a reference on file.
-func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
- if minfd < 0 {
- // Don't accept negative FDs.
- return -1, unix.EINVAL
- }
-
- // Default limit.
- end := int32(math.MaxInt32)
-
- // Ensure we don't get past the provided limit.
- if limitSet := limits.FromContext(ctx); limitSet != nil {
- lim := limitSet.Get(limits.NumberOfFiles)
- if lim.Cur != limits.Infinity {
- end = int32(lim.Cur)
- }
- if minfd >= end {
- return -1, unix.EMFILE
- }
- }
-
- f.mu.Lock()
- defer f.mu.Unlock()
-
- // From f.next to find available fd.
- fd := minfd
- if fd < f.next {
- fd = f.next
- }
- for fd < end {
- if d, _, _ := f.getVFS2(fd); d == nil {
- f.setVFS2(ctx, fd, file, flags)
- if fd == f.next {
- // Update next search start position.
- f.next = fd + 1
- }
- return fd, nil
- }
- fd++
+func (f *FDTable) NewFDVFS2(ctx context.Context, minFD int32, file *vfs.FileDescription, flags FDFlags) (int32, error) {
+ files := []*vfs.FileDescription{file}
+ fileSlice, err := f.NewFDsVFS2(ctx, minFD, files, flags)
+ if err != nil {
+ return -1, err
}
- return -1, unix.EMFILE
+ return fileSlice[0], nil
}
// NewFDAt sets the file reference for the given FD. If there is an active
@@ -469,6 +463,11 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2
defer f.mu.Unlock()
df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags)
+ // Add fd to fdBitmap.
+ if file != nil || fileVFS2 != nil {
+ f.fdBitmap.Add(uint32(fd))
+ }
+
return df, dfVFS2, nil
}
@@ -573,7 +572,9 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) {
// Precondition: The caller must be running on the task goroutine, or Task.mu
// must be locked.
func (f *FDTable) GetFDs(ctx context.Context) []int32 {
- fds := make([]int32, 0, int(atomic.LoadInt32(&f.used)))
+ f.mu.Lock()
+ defer f.mu.Unlock()
+ fds := make([]int32, 0, int(f.fdBitmap.GetNumOnes()))
f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) {
fds = append(fds, fd)
})
@@ -583,13 +584,15 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 {
// Fork returns an independent FDTable.
func (f *FDTable) Fork(ctx context.Context) *FDTable {
clone := f.k.NewFDTable()
-
+ f.mu.Lock()
+ defer f.mu.Unlock()
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
// The set function here will acquire an appropriate table
// reference for the clone. We don't need anything else.
if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil {
panic("VFS1 or VFS2 files set")
}
+ clone.fdBitmap.Add(uint32(fd))
})
return clone
}
@@ -604,11 +607,6 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
f.mu.Lock()
- // Update current available position.
- if fd < f.next {
- f.next = fd
- }
-
orig, orig2, _, _ := f.getAll(fd)
// Add reference for caller.
@@ -621,6 +619,7 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc
if orig != nil || orig2 != nil {
orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry.
+ f.fdBitmap.Remove(uint32(fd))
}
f.mu.Unlock()
@@ -644,16 +643,13 @@ func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDes
f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) {
if cond(file, fileVFS2, flags) {
df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table.
+ f.fdBitmap.Remove(uint32(fd))
if df != nil {
files = append(files, df)
}
if dfVFS2 != nil {
filesVFS2 = append(filesVFS2, dfVFS2)
}
- // Update current available position.
- if fd < f.next {
- f.next = fd
- }
}
})
f.mu.Unlock()
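Condensed, the new allocation strategy is: take the first zero bit at or after minFD, and once FirstZero runs off the end of the bitmap, fall back to one past the current maximum. A single-descriptor sketch, ignoring locking and the multi-file rollback (lowestFreeFD is a hypothetical helper, not part of the change):

package main

import (
    "fmt"
    "math"

    "gvisor.dev/gvisor/pkg/bitmap"
)

// lowestFreeFD mirrors the allocation loop in NewFDs: take the first zero
// bit at or after minFD, fall back to one past the current maximum when the
// bitmap has no zero to offer, and respect the RLIMIT_NOFILE-style end.
func lowestFreeFD(b *bitmap.Bitmap, minFD, end int32) (int32, bool) {
    max := int32(0)
    if !b.IsEmpty() {
        max = int32(b.Maximum()) + 1
    }
    if max < minFD {
        max = minFD
    }
    fd := b.FirstZero(uint32(minFD))
    if fd == math.MaxInt32 {
        fd = uint32(max)
    }
    if fd >= uint32(end) {
        return -1, false
    }
    b.Add(fd)
    return int32(fd), true
}

func main() {
    b := bitmap.New(uint32(math.MaxUint16))
    b.Add(0)
    b.Add(1)
    b.Add(3)
    fd, _ := lowestFreeFD(&b, 0, 1024)
    fmt.Println(fd) // 2: the lowest free descriptor.
    fd, _ = lowestFreeFD(&b, 0, 1024)
    fmt.Println(fd) // 4: fds 0-3 are now taken.
}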
diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go
index f17f9c59c..2b3e6ef71 100644
--- a/pkg/sentry/kernel/fd_table_unsafe.go
+++ b/pkg/sentry/kernel/fd_table_unsafe.go
@@ -15,9 +15,11 @@
package kernel
import (
+ "math"
"sync/atomic"
"unsafe"
+ "gvisor.dev/gvisor/pkg/bitmap"
"gvisor.dev/gvisor/pkg/context"
"gvisor.dev/gvisor/pkg/sentry/fs"
"gvisor.dev/gvisor/pkg/sentry/vfs"
@@ -44,6 +46,7 @@ func (f *FDTable) initNoLeakCheck() {
func (f *FDTable) init() {
f.initNoLeakCheck()
f.InitRefs()
+ f.fdBitmap = bitmap.New(uint32(math.MaxUint16))
}
// get gets a file entry.
@@ -162,14 +165,6 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2
}
}
- // Adjust used.
- switch {
- case orig == nil && desc != nil:
- atomic.AddInt32(&f.used, 1)
- case orig != nil && desc == nil:
- atomic.AddInt32(&f.used, -1)
- }
-
if orig != nil {
switch {
case orig.file != nil:
diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go
index 3ce926950..c111297d7 100644
--- a/pkg/sentry/kernel/msgqueue/msgqueue.go
+++ b/pkg/sentry/kernel/msgqueue/msgqueue.go
@@ -119,14 +119,21 @@ type Queue struct {
type Message struct {
msgEntry
- // mType is an integer representing the type of the sent message.
- mType int64
+ // Type is an integer representing the type of the sent message.
+ Type int64
- // mText is an untyped block of memory.
- mText []byte
+ // Text is an untyped block of memory.
+ Text []byte
- // mSize is the size of mText.
- mSize uint64
+ // Size is the size of Text.
+ Size uint64
+}
+
+// Blocker is used for blocking Queue.Send and Queue.Receive calls. It serves
+// as an abstracted version of kernel.Task; kernel.Task is not used directly
+// to avoid a circular dependency.
+type Blocker interface {
+ Block(C <-chan struct{}) error
}
// FindOrCreate creates a new message queue or returns an existing one. See
@@ -186,6 +193,265 @@ func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error {
return nil
}
+// FindByID returns the queue with the specified ID and an error if the ID
+// doesn't exist.
+func (r *Registry) FindByID(id ipc.ID) (*Queue, error) {
+ r.mu.Lock()
+ defer r.mu.Unlock()
+
+ mech := r.reg.FindByID(id)
+ if mech == nil {
+ return nil, linuxerr.EINVAL
+ }
+ return mech.(*Queue), nil
+}
+
+// Send appends a message to the message queue, and returns an error if sending
+// fails. See msgsnd(2).
+func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) {
+ // Try to perform a non-blocking send using queue.append. If EWOULDBLOCK
+ // is returned, start the blocking procedure. Otherwise, return normally.
+ creds := auth.CredentialsFromContext(ctx)
+ if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+ return err
+ }
+
+ if !wait {
+ return linuxerr.EAGAIN
+ }
+
+ e, ch := waiter.NewChannelEntry(nil)
+ q.senders.EventRegister(&e, waiter.EventOut)
+
+ for {
+ if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK {
+ break
+ }
+ b.Block(ch)
+ }
+
+ q.senders.EventUnregister(&e)
+ return err
+}
+
+// append appends a message to the queue's message list and notifies waiting
+// receivers that a message has been inserted. It returns an error if adding
+// the message would cause the queue to exceed its maximum capacity, which can
+// be used as a signal to block the task. Other errors should be returned as is.
+func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error {
+ if m.Type <= 0 {
+ return linuxerr.EINVAL
+ }
+
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.obj.CheckPermissions(creds, fs.PermMask{Write: true}) {
+ // The calling process does not have write permission on the message
+ // queue, and does not have the CAP_IPC_OWNER capability in the user
+ // namespace that governs its IPC namespace.
+ return linuxerr.EACCES
+ }
+
+ // Queue was removed while the process was waiting.
+ if q.dead {
+ return linuxerr.EIDRM
+ }
+
+ // Check if sufficient space is available (the queue isn't full). From
+ // the man pages:
+ //
+ // "A message queue is considered to be full if either of the following
+ // conditions is true:
+ //
+ // • Adding a new message to the queue would cause the total number
+ // of bytes in the queue to exceed the queue's maximum size (the
+ // msg_qbytes field).
+ //
+ // • Adding another message to the queue would cause the total
+ // number of messages in the queue to exceed the queue's maximum
+ // size (the msg_qbytes field). This check is necessary to
+ // prevent an unlimited number of zero-length messages being
+ // placed on the queue. Although such messages contain no data,
+ // they nevertheless consume (locked) kernel memory."
+ //
+ // The msg_qbytes field in our implementation is q.maxBytes.
+ if m.Size+q.byteCount > q.maxBytes || q.messageCount+1 > q.maxBytes {
+ return linuxerr.EWOULDBLOCK
+ }
+
+ // Copy the message into the queue.
+ q.messages.PushBack(&m)
+
+ q.byteCount += m.Size
+ q.messageCount++
+ q.sendPID = pid
+ q.sendTime = ktime.NowFromContext(ctx)
+
+ // Notify receivers about the new message.
+ q.receivers.Notify(waiter.EventIn)
+
+ return nil
+}
+
+// Receive removes a message from the queue and returns it. See msgrcv(2).
+func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) {
+ if maxSize < 0 || maxSize > maxMessageBytes {
+ return nil, linuxerr.EINVAL
+ }
+ max := uint64(maxSize)
+
+ // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK
+ // is returned, start the blocking procedure. Otherwise, return normally.
+ creds := auth.CredentialsFromContext(ctx)
+ if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
+ return msg, err
+ }
+
+ if !wait {
+ return nil, linuxerr.ENOMSG
+ }
+
+ e, ch := waiter.NewChannelEntry(nil)
+ q.receivers.EventRegister(&e, waiter.EventIn)
+
+ for {
+ if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK {
+ break
+ }
+ b.Block(ch)
+ }
+ q.receivers.EventUnregister(&e)
+ return msg, err
+}
+
+// pop pops the first message from the queue that matches the given type. It
+// returns an error for all the cases specified in msgrcv(2). If the queue is
+// empty or no message of the specified type is available, an EWOULDBLOCK error
+// is returned, which can then be used as a signal to block the process or fail.
+func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if !q.obj.CheckPermissions(creds, fs.PermMask{Read: true}) {
+ // The calling process does not have read permission on the message
+ // queue, and does not have the CAP_IPC_OWNER capability in the user
+ // namespace that governs its IPC namespace.
+ return nil, linuxerr.EACCES
+ }
+
+ // Queue was removed while the process was waiting.
+ if q.dead {
+ return nil, linuxerr.EIDRM
+ }
+
+ if q.messages.Empty() {
+ return nil, linuxerr.EWOULDBLOCK
+ }
+
+ // Get a message from the queue.
+ switch {
+ case mType == 0:
+ msg = q.messages.Front()
+ case mType > 0:
+ msg = q.msgOfType(mType, except)
+ case mType < 0:
+ msg = q.msgOfTypeLessThan(-1 * mType)
+ }
+
+ // If no message exists, return a blocking signal.
+ if msg == nil {
+ return nil, linuxerr.EWOULDBLOCK
+ }
+
+ // Check message's size is acceptable.
+ if maxSize < msg.Size {
+ if !truncate {
+ return nil, linuxerr.E2BIG
+ }
+ msg.Size = maxSize
+ msg.Text = msg.Text[:maxSize]
+ }
+
+ q.messages.Remove(msg)
+
+ q.byteCount -= msg.Size
+ q.messageCount--
+ q.receivePID = pid
+ q.receiveTime = ktime.NowFromContext(ctx)
+
+ // Notify senders about available space.
+ q.senders.Notify(waiter.EventOut)
+
+ return msg, nil
+}
+
+// Copy copies a message from the queue without deleting it. If no message
+// exists, an error is returned. See msgrcv(MSG_COPY).
+func (q *Queue) Copy(mType int64) (*Message, error) {
+ q.mu.Lock()
+ defer q.mu.Unlock()
+
+ if mType < 0 || q.messages.Empty() {
+ return nil, linuxerr.ENOMSG
+ }
+
+ msg := q.msgAtIndex(mType)
+ if msg == nil {
+ return nil, linuxerr.ENOMSG
+ }
+ return msg, nil
+}
+
+// msgOfType returns the first message with the specified type, or nil if no
+// message is found. If except is true, the first message of a type not equal
+// to mType will be returned.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgOfType(mType int64, except bool) *Message {
+ if except {
+ for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+ if msg.Type != mType {
+ return msg
+ }
+ }
+ return nil
+ }
+
+ for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+ if msg.Type == mType {
+ return msg
+ }
+ }
+ return nil
+}
+
+// msgOfTypeLessThan returns the first message with the lowest type less than
+// or equal to mType, or nil if no such message exists.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgOfTypeLessThan(mType int64) (m *Message) {
+ min := mType
+ for msg := q.messages.Front(); msg != nil; msg = msg.Next() {
+ if msg.Type <= mType && (m == nil || msg.Type < min) {
+ m = msg
+ min = msg.Type
+ }
+ }
+ return m
+}
+
+// msgAtIndex returns a pointer to the message at the given index, or nil if
+// none exists.
+//
+// Precondition: caller must hold q.mu.
+func (q *Queue) msgAtIndex(mType int64) *Message {
+ msg := q.messages.Front()
+ for ; mType != 0 && msg != nil; mType-- {
+ msg = msg.Next()
+ }
+ return msg
+}
+
// Lock implements ipc.Mechanism.Lock.
func (q *Queue) Lock() {
q.mu.Lock()
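
For reviewers tracing the new Send/Receive paths above, a minimal usage sketch against the Blocker abstraction. The trivialBlocker and sendThenReceive names are invented for illustration; the real kernel.Task.Block additionally handles signal interruption, which this sketch omits:

    import (
        "gvisor.dev/gvisor/pkg/context"
        "gvisor.dev/gvisor/pkg/sentry/kernel/msgqueue"
    )

    // trivialBlocker satisfies msgqueue.Blocker by waiting on the event
    // channel unconditionally (no interruption support).
    type trivialBlocker struct{}

    func (trivialBlocker) Block(C <-chan struct{}) error {
        <-C // wait for the queue to signal readiness; the caller then retries
        return nil
    }

    func sendThenReceive(ctx context.Context, q *msgqueue.Queue) error {
        b := trivialBlocker{}
        m := msgqueue.Message{Type: 1, Text: []byte("ping"), Size: 4}
        // wait == true: block until the queue has room, as msgsnd(2) does
        // without IPC_NOWAIT.
        if err := q.Send(ctx, m, b, true /* wait */, 42 /* pid */); err != nil {
            return err
        }
        // mType == 0: take the first message, blocking until one arrives.
        got, err := q.Receive(ctx, b, 0 /* mType */, 64 /* maxSize */, true, false, false, 42)
        if err != nil {
            return err
        }
        _ = got.Text
        return nil
    }
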
diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go
index 28a613a54..8fd8287b3 100644
--- a/pkg/sentry/platform/kvm/bluepill_fault.go
+++ b/pkg/sentry/platform/kvm/bluepill_fault.go
@@ -101,7 +101,7 @@ func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegi
// Store the physical address in the slot. This is used to
// avoid calls to handleBluepillFault in the future (see
// machine.mapPhysical).
- atomic.StoreUintptr(&m.usedSlots[slot], physical)
+ atomic.StoreUintptr(&m.usedSlots[slot], physicalStart)
// Successfully added region; we can increment nextSlot and
// allow another set to proceed here.
atomic.StoreUint32(&m.nextSlot, slot+1)
diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go
index f63ab6aba..0f0c1e73b 100644
--- a/pkg/sentry/platform/kvm/bluepill_unsafe.go
+++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build go1.12 && !go1.18
-// +build go1.12,!go1.18
+//go:build go1.12
+// +build go1.12
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives are type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
package kvm
@@ -28,7 +30,7 @@ import (
)
//go:linkname throw runtime.throw
-func throw(string)
+func throw(s string)
// vCPUPtr returns a CPU for the given address.
//
diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go
index 1b5d5f66e..e7092a756 100644
--- a/pkg/sentry/platform/kvm/machine.go
+++ b/pkg/sentry/platform/kvm/machine.go
@@ -70,7 +70,7 @@ type machine struct {
// tscControl checks whether cpu supports TSC scaling
tscControl bool
- // usedSlots is the set of used physical addresses (sorted).
+ // usedSlots is the set of used physical addresses (not sorted).
usedSlots []uintptr
// nextID is the next vCPU ID.
@@ -296,13 +296,20 @@ func newMachine(vm int) (*machine, error) {
return m, nil
}
-// hasSlot returns true iff the given address is mapped.
+// hasSlot returns true if the given address is mapped.
//
// This must be done via a linear scan.
//
//go:nosplit
func (m *machine) hasSlot(physical uintptr) bool {
- for i := 0; i < len(m.usedSlots); i++ {
+ slotLen := int(atomic.LoadUint32(&m.nextSlot))
+ // When slots are being updated, nextSlot is ^uint32(0). As this situation
+ // is unlikely to happen, we just set slotLen to m.maxSlots and scan the
+ // whole usedSlots array.
+ if slotLen == int(^uint32(0)) {
+ slotLen = m.maxSlots
+ }
+ for i := 0; i < slotLen; i++ {
if p := atomic.LoadUintptr(&m.usedSlots[i]); p == physical {
return true
}
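
The usedSlots/nextSlot interplay above is a small lock-free handshake: a writer parks nextSlot at the ^uint32(0) sentinel while it publishes a new address, and readers scan up to the published count, falling back to the whole array while an update is in flight. A self-contained sketch of the same pattern (all names invented for illustration):

    package main

    import "sync/atomic"

    const maxSlots = 16
    const updating = ^uint32(0) // sentinel: a writer is mid-publish

    type table struct {
        nextSlot uint32
        slots    [maxSlots]uintptr
    }

    // add publishes addr in the next free slot.
    func (t *table) add(addr uintptr) bool {
        for {
            next := atomic.LoadUint32(&t.nextSlot)
            if next == updating {
                continue // another writer is publishing; retry
            }
            if next == maxSlots {
                return false
            }
            // Reserve the slot by parking nextSlot at the sentinel.
            if !atomic.CompareAndSwapUint32(&t.nextSlot, next, updating) {
                continue
            }
            atomic.StoreUintptr(&t.slots[next], addr)
            atomic.StoreUint32(&t.nextSlot, next+1)
            return true
        }
    }

    // has mirrors machine.hasSlot: scan up to nextSlot, or the whole
    // array if an update is in flight.
    func (t *table) has(addr uintptr) bool {
        n := atomic.LoadUint32(&t.nextSlot)
        if n == updating {
            n = maxSlots
        }
        for i := uint32(0); i < n; i++ {
            if atomic.LoadUintptr(&t.slots[i]) == addr {
                return true
            }
        }
        return false
    }
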
diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go
index 35660e827..cc3a1253b 100644
--- a/pkg/sentry/platform/kvm/machine_unsafe.go
+++ b/pkg/sentry/platform/kvm/machine_unsafe.go
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build go1.12 && !go1.18
-// +build go1.12,!go1.18
+//go:build go1.12
+// +build go1.12
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives are type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
package kvm
diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
index ffd4665f4..304722200 100644
--- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go
+++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go
@@ -12,10 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build go1.12 && !go1.18
-// +build go1.12,!go1.18
+//go:build go1.12
+// +build go1.12
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives are type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
package ptrace
diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go
index 7a4e78a5f..61111ac6c 100644
--- a/pkg/sentry/socket/hostinet/stack.go
+++ b/pkg/sentry/socket/hostinet/stack.go
@@ -309,6 +309,11 @@ func (s *Stack) Interfaces() map[int32]inet.Interface {
return interfaces
}
+// RemoveInterface implements inet.Stack.RemoveInterface.
+func (*Stack) RemoveInterface(int32) error {
+ return linuxerr.EACCES
+}
+
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
addrs := make(map[int32][]inet.InterfaceAddr)
@@ -319,12 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
}
// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr.
-func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
+func (*Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error {
return linuxerr.EACCES
}
// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr.
-func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
+func (*Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error {
return linuxerr.EACCES
}
@@ -339,7 +344,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) {
}
// SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize.
-func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error {
+func (*Stack) SetTCPReceiveBufferSize(inet.TCPBufferSize) error {
return linuxerr.EACCES
}
@@ -349,7 +354,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) {
}
// SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize.
-func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error {
+func (*Stack) SetTCPSendBufferSize(inet.TCPBufferSize) error {
return linuxerr.EACCES
}
@@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) {
}
// SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled.
-func (s *Stack) SetTCPSACKEnabled(bool) error {
+func (*Stack) SetTCPSACKEnabled(bool) error {
return linuxerr.EACCES
}
@@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) {
}
// SetTCPRecovery implements inet.Stack.SetTCPRecovery.
-func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
+func (*Stack) SetTCPRecovery(inet.TCPLossRecovery) error {
return linuxerr.EACCES
}
@@ -470,19 +475,19 @@ func (s *Stack) RouteTable() []inet.Route {
}
// Resume implements inet.Stack.Resume.
-func (s *Stack) Resume() {}
+func (*Stack) Resume() {}
// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints.
-func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
+func (*Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil }
// CleanupEndpoints implements inet.Stack.CleanupEndpoints.
-func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
+func (*Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil }
// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints.
-func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
+func (*Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {}
// SetForwarding implements inet.Stack.SetForwarding.
-func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
+func (*Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error {
return linuxerr.EACCES
}
@@ -493,6 +498,6 @@ func (*Stack) PortRange() (uint16, uint16) {
}
// SetPortRange implements inet.Stack.SetPortRange.
-func (*Stack) SetPortRange(start uint16, end uint16) error {
+func (*Stack) SetPortRange(uint16, uint16) error {
return linuxerr.EACCES
}
diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go
index 86f6419dc..d526acb73 100644
--- a/pkg/sentry/socket/netlink/route/protocol.go
+++ b/pkg/sentry/socket/netlink/route/protocol.go
@@ -161,6 +161,47 @@ func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlin
return nil
}
+// delLink handles RTM_DELLINK requests.
+func (p *Protocol) delLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error {
+ stack := inet.StackFromContext(ctx)
+ if stack == nil {
+ // No network stack.
+ return syserr.ErrProtocolNotSupported
+ }
+
+ var ifinfomsg linux.InterfaceInfoMessage
+ attrs, ok := msg.GetData(&ifinfomsg)
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ if ifinfomsg.Index == 0 {
+ // The index is unspecified, search by the interface name.
+ ahdr, value, _, ok := attrs.ParseFirst()
+ if !ok {
+ return syserr.ErrInvalidArgument
+ }
+ switch ahdr.Type {
+ case linux.IFLA_IFNAME:
+ if len(value) < 1 {
+ return syserr.ErrInvalidArgument
+ }
+ ifname := string(value[:len(value)-1])
+ for idx, ifa := range stack.Interfaces() {
+ if ifname == ifa.Name {
+ ifinfomsg.Index = idx
+ break
+ }
+ }
+ default:
+ return syserr.ErrInvalidArgument
+ }
+ if ifinfomsg.Index == 0 {
+ return syserr.ErrNoDevice
+ }
+ }
+ return syserr.FromError(stack.RemoveInterface(ifinfomsg.Index))
+}
+
// addNewLinkMessage appends RTM_NEWLINK message for the given interface into
// the message set.
func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) {
@@ -537,6 +578,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms
switch hdr.Type {
case linux.RTM_GETLINK:
return p.getLink(ctx, msg, ms)
+ case linux.RTM_DELLINK:
+ return p.delLink(ctx, msg, ms)
case linux.RTM_GETROUTE:
return p.dumpRoutes(ctx, msg, ms)
case linux.RTM_NEWADDR:
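
For context, delLink is exercised by any RTM_DELLINK producer inside the sandbox; `ip link delete eth1` is one. A guest-side sketch using the third-party vishvananda/netlink package (an assumed convenience here, not part of this change):

    package main

    import (
        "log"

        "github.com/vishvananda/netlink"
    )

    func main() {
        // Resolve the interface by name; delLink also accepts an
        // IFLA_IFNAME attribute when the index is unspecified.
        link, err := netlink.LinkByName("eth1")
        if err != nil {
            log.Fatalf("LinkByName: %v", err)
        }
        // Sends RTM_DELLINK with the resolved index.
        if err := netlink.LinkDel(link); err != nil {
            log.Fatalf("LinkDel: %v", err)
        }
    }
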
diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go
index 0fd0ad32c..208ab9909 100644
--- a/pkg/sentry/socket/netstack/stack.go
+++ b/pkg/sentry/socket/netstack/stack.go
@@ -71,6 +71,12 @@ func (s *Stack) Interfaces() map[int32]inet.Interface {
return is
}
+// RemoveInterface implements inet.Stack.RemoveInterface.
+func (s *Stack) RemoveInterface(idx int32) error {
+ nic := tcpip.NICID(idx)
+ return syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)).ToError()
+}
+
// InterfaceAddrs implements inet.Stack.InterfaceAddrs.
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr {
nicAddrs := make(map[int32][]inet.InterfaceAddr)
diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD
index ccccce6a9..b5a371d9a 100644
--- a/pkg/sentry/syscalls/linux/BUILD
+++ b/pkg/sentry/syscalls/linux/BUILD
@@ -86,6 +86,7 @@ go_library(
"//pkg/sentry/kernel/eventfd",
"//pkg/sentry/kernel/fasync",
"//pkg/sentry/kernel/ipc",
+ "//pkg/sentry/kernel/msgqueue",
"//pkg/sentry/kernel/pipe",
"//pkg/sentry/kernel/sched",
"//pkg/sentry/kernel/shm",
diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go
index 6f44d767b..1ead3c7e8 100644
--- a/pkg/sentry/syscalls/linux/linux64.go
+++ b/pkg/sentry/syscalls/linux/linux64.go
@@ -122,8 +122,8 @@ var AMD64 = &kernel.SyscallTable{
66: syscalls.Supported("semctl", Semctl),
67: syscalls.Supported("shmdt", Shmdt),
68: syscalls.Supported("msgget", Msgget),
- 69: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 70: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 69: syscalls.Supported("msgsnd", Msgsnd),
+ 70: syscalls.Supported("msgrcv", Msgrcv),
71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil),
73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil),
@@ -618,8 +618,8 @@ var ARM64 = &kernel.SyscallTable{
185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921)
186: syscalls.Supported("msgget", Msgget),
187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}),
- 188: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
- 189: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921)
+ 188: syscalls.Supported("msgrcv", Msgrcv),
+ 189: syscalls.Supported("msgsnd", Msgsnd),
190: syscalls.Supported("semget", Semget),
191: syscalls.Supported("semctl", Semctl),
192: syscalls.Supported("semtimedop", Semtimedop),
diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go
index 3476e218d..5259ade90 100644
--- a/pkg/sentry/syscalls/linux/sys_msgqueue.go
+++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go
@@ -17,10 +17,12 @@ package linux
import (
"gvisor.dev/gvisor/pkg/abi/linux"
"gvisor.dev/gvisor/pkg/errors/linuxerr"
+ "gvisor.dev/gvisor/pkg/marshal/primitive"
"gvisor.dev/gvisor/pkg/sentry/arch"
"gvisor.dev/gvisor/pkg/sentry/kernel"
"gvisor.dev/gvisor/pkg/sentry/kernel/auth"
"gvisor.dev/gvisor/pkg/sentry/kernel/ipc"
+ "gvisor.dev/gvisor/pkg/sentry/kernel/msgqueue"
)
// Msgget implements msgget(2).
@@ -41,6 +43,89 @@ func Msgget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal
return uintptr(queue.ID()), nil, nil
}
+// Msgsnd implements msgsnd(2).
+func Msgsnd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := ipc.ID(args[0].Int())
+ msgAddr := args[1].Pointer()
+ size := args[2].Int64()
+ flag := args[3].Int()
+
+ if size < 0 || size > linux.MSGMAX {
+ return 0, nil, linuxerr.EINVAL
+ }
+
+ wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT
+ pid := int32(t.ThreadGroup().ID())
+
+ buf := linux.MsgBuf{
+ Text: make([]byte, size),
+ }
+ if _, err := buf.CopyIn(t, msgAddr); err != nil {
+ return 0, nil, err
+ }
+
+ queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ msg := msgqueue.Message{
+ Type: int64(buf.Type),
+ Text: buf.Text,
+ Size: uint64(size),
+ }
+ return 0, nil, queue.Send(t, msg, t, wait, pid)
+}
+
+// Msgrcv implements msgrcv(2).
+func Msgrcv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
+ id := ipc.ID(args[0].Int())
+ msgAddr := args[1].Pointer()
+ size := args[2].Int64()
+ mType := args[3].Int64()
+ flag := args[4].Int()
+
+ wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT
+ except := flag&linux.MSG_EXCEPT == linux.MSG_EXCEPT
+ truncate := flag&linux.MSG_NOERROR == linux.MSG_NOERROR
+
+ msgCopy := flag&linux.MSG_COPY == linux.MSG_COPY
+
+ msg, err := receive(t, id, mType, size, msgCopy, wait, truncate, except)
+ if err != nil {
+ return 0, nil, err
+ }
+
+ buf := linux.MsgBuf{
+ Type: primitive.Int64(msg.Type),
+ Text: msg.Text,
+ }
+ if _, err := buf.CopyOut(t, msgAddr); err != nil {
+ return 0, nil, err
+ }
+ return uintptr(msg.Size), nil, nil
+}
+
+// receive returns a message from the queue with the given ID. If msgCopy is
+// true, a message is copied from the queue without being removed. Otherwise,
+// a message is removed from the queue and returned.
+func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wait, truncate, except bool) (*msgqueue.Message, error) {
+ pid := int32(t.ThreadGroup().ID())
+
+ queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id)
+ if err != nil {
+ return nil, err
+ }
+
+ if msgCopy {
+ if wait || except {
+ return nil, linuxerr.EINVAL
+ }
+ return queue.Copy(mType)
+ }
+ return queue.Receive(t, t, mType, maxSize, wait, truncate, except, pid)
+}
+
// Msgctl implements msgctl(2).
func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) {
id := ipc.ID(args[0].Int())
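
A userspace sketch of the syscall sequence these handlers now service, issued via raw syscalls (assumes linux/amd64 for the syscall numbers and the little-endian struct msgbuf layout; error handling abbreviated):

    package main

    import (
        "encoding/binary"
        "fmt"
        "unsafe"

        "golang.org/x/sys/unix"
    )

    func main() {
        // msgget(IPC_PRIVATE, IPC_CREAT|0600)
        id, _, errno := unix.Syscall(unix.SYS_MSGGET, unix.IPC_PRIVATE, unix.IPC_CREAT|0600, 0)
        if errno != 0 {
            panic(errno)
        }

        // struct msgbuf { long mtype; char mtext[]; }; mtype must be > 0.
        text := []byte("ping")
        snd := make([]byte, 8+len(text))
        binary.LittleEndian.PutUint64(snd, 1) // mtype = 1
        copy(snd[8:], text)
        if _, _, errno := unix.Syscall6(unix.SYS_MSGSND, id,
            uintptr(unsafe.Pointer(&snd[0])), uintptr(len(text)), 0, 0, 0); errno != 0 {
            panic(errno)
        }

        // msgrcv(id, &rcv, 64, 0 /* first message */, 0)
        rcv := make([]byte, 8+64)
        n, _, errno := unix.Syscall6(unix.SYS_MSGRCV, id,
            uintptr(unsafe.Pointer(&rcv[0])), 64, 0, 0, 0)
        if errno != 0 {
            panic(errno)
        }
        fmt.Printf("type=%d text=%q\n", binary.LittleEndian.Uint64(rcv), rcv[8:8+int(n)])
    }
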
diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go
index a16b6b4d6..2ef1e6404 100644
--- a/pkg/sentry/syscalls/linux/sys_prctl.go
+++ b/pkg/sentry/syscalls/linux/sys_prctl.go
@@ -219,6 +219,21 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
}
return 0, nil, t.DropBoundingCapability(cp)
+ case linux.PR_SET_CHILD_SUBREAPER:
+ // "If arg2 is nonzero, set the "child subreaper" attribute of
+ // the calling process; if arg2 is zero, unset the attribute."
+ //
+ // TODO(gvisor.dev/issues/2323): We only support setting, and
+ // only if the task is already TID 1 in the PID namespace,
+ // because it already acts as a subreaper in that case.
+ isPid1 := t.PIDNamespace().IDOfTask(t) == kernel.InitTID
+ if args[1].Int() != 0 && isPid1 {
+ return 0, nil, nil
+ }
+
+ t.Kernel().EmitUnimplementedEvent(t)
+ return 0, nil, linuxerr.EINVAL
+
case linux.PR_GET_TIMING,
linux.PR_SET_TIMING,
linux.PR_GET_TSC,
@@ -230,7 +245,6 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall
linux.PR_MCE_KILL,
linux.PR_MCE_KILL_GET,
linux.PR_GET_TID_ADDRESS,
- linux.PR_SET_CHILD_SUBREAPER,
linux.PR_GET_CHILD_SUBREAPER,
linux.PR_GET_THP_DISABLE,
linux.PR_SET_THP_DISABLE,
diff --git a/pkg/sync/goyield_unsafe.go b/pkg/sync/goyield_unsafe.go
index 8639bb64e..757edbaba 100644
--- a/pkg/sync/goyield_unsafe.go
+++ b/pkg/sync/goyield_unsafe.go
@@ -3,10 +3,12 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-//go:build go1.14 && !go1.18
-// +build go1.14,!go1.18
+//go:build go1.14
+// +build go1.14
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives are type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
package sync
diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go
index 1d9cf304e..49d4109a9 100644
--- a/pkg/sync/runtime_unsafe.go
+++ b/pkg/sync/runtime_unsafe.go
@@ -6,8 +6,13 @@
//go:build go1.13 && !go1.18
// +build go1.13,!go1.18
-// Check go:linkname function signatures, type definitions, and constants when
-// updating Go version.
+// //go:linkname directives are type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
+
+// Check type definitions and constants when updating Go version.
+//
+// TODO(b/165820485): add these checks to checklinkname.
package sync
@@ -109,10 +114,10 @@ type maptype struct {
// These functions are only used within the sync package.
//go:linkname semacquire sync.runtime_Semacquire
-func semacquire(s *uint32)
+func semacquire(addr *uint32)
//go:linkname semrelease sync.runtime_Semrelease
-func semrelease(s *uint32, handoff bool, skipframes int)
+func semrelease(addr *uint32, handoff bool, skipframes int)
//go:linkname canSpin sync.runtime_canSpin
func canSpin(i int) bool
diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go
index e8e716db0..48356c343 100644
--- a/pkg/tcpip/link/fdbased/endpoint.go
+++ b/pkg/tcpip/link/fdbased/endpoint.go
@@ -56,6 +56,7 @@ import (
// linkDispatcher reads packets from the link FD and dispatches them to the
// NetworkDispatcher.
type linkDispatcher interface {
+ stop()
dispatch() (bool, tcpip.Error)
}
@@ -381,16 +382,27 @@ func isSocketFD(fd int) (bool, error) {
// Attach launches the goroutine that reads packets from the file descriptor and
// dispatches them via the provided dispatcher.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
- e.dispatcher = dispatcher
- // Link endpoints are not savable. When transportation endpoints are
- // saved, they stop sending outgoing packets and all incoming packets
- // are rejected.
- for i := range e.inboundDispatchers {
- e.wg.Add(1)
- go func(i int) { // S/R-SAFE: See above.
- e.dispatchLoop(e.inboundDispatchers[i])
- e.wg.Done()
- }(i)
+ // nil means the NIC is being removed.
+ if dispatcher == nil && e.dispatcher != nil {
+ for _, dispatcher := range e.inboundDispatchers {
+ dispatcher.stop()
+ }
+ e.Wait()
+ e.dispatcher = nil
+ return
+ }
+ if dispatcher != nil && e.dispatcher == nil {
+ e.dispatcher = dispatcher
+ // Link endpoints are not savable. When transportation endpoints are
+ // saved, they stop sending outgoing packets and all incoming packets
+ // are rejected.
+ for i := range e.inboundDispatchers {
+ e.wg.Add(1)
+ go func(i int) { // S/R-SAFE: See above.
+ e.dispatchLoop(e.inboundDispatchers[i])
+ e.wg.Done()
+ }(i)
+ }
}
}
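
In caller terms, the new contract is that Attach(nil) detaches: it stops every inbound dispatcher and waits for the dispatch goroutines to drain before clearing the dispatcher. A hedged sketch (option values illustrative):

    import (
        "gvisor.dev/gvisor/pkg/tcpip/link/fdbased"
        "gvisor.dev/gvisor/pkg/tcpip/stack"
    )

    // attachThenDetach shows both halves of the contract; fd wiring and
    // the MTU value are illustrative.
    func attachThenDetach(fd int, d stack.NetworkDispatcher) error {
        ep, err := fdbased.New(&fdbased.Options{FDs: []int{fd}, MTU: 1500})
        if err != nil {
            return err
        }
        ep.Attach(d)   // spawns one goroutine per inbound dispatcher
        ep.Attach(nil) // NIC removal path: stop dispatchers and wait
        return nil
    }
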
diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go
index bfae34ab9..3f516cab5 100644
--- a/pkg/tcpip/link/fdbased/mmap.go
+++ b/pkg/tcpip/link/fdbased/mmap.go
@@ -114,6 +114,7 @@ func (t tPacketHdr) Payload() []byte {
// packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets.
// See: mmap_amd64_unsafe.go for implementation details.
type packetMMapDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -129,18 +130,18 @@ type packetMMapDispatcher struct {
ringOffset int
}
-func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
+func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, bool, tcpip.Error) {
hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:])
for hdr.tpStatus()&tpStatusUser == 0 {
- event := rawfile.PollEvent{
- FD: int32(d.fd),
- Events: unix.POLLIN | unix.POLLERR,
- }
- if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 {
+ stopped, errno := rawfile.BlockingPollUntilStopped(d.efd, d.fd, unix.POLLIN|unix.POLLERR)
+ if errno != 0 {
if errno == unix.EINTR {
continue
}
- return nil, rawfile.TranslateErrno(errno)
+ return nil, stopped, rawfile.TranslateErrno(errno)
+ }
+ if stopped {
+ return nil, true, nil
}
if hdr.tpStatus()&tpStatusCopy != 0 {
// This frame is truncated so skip it after flipping the
@@ -158,14 +159,14 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) {
// Release packet to kernel.
hdr.setTPStatus(tpStatusKernel)
d.ringOffset = (d.ringOffset + 1) % tpFrameNR
- return pkt, nil
+ return pkt, false, nil
}
// dispatch reads packets from an mmaped ring buffer and dispatches them to the
// network stack.
func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) {
- pkt, err := d.readMMappedPacket()
- if err != nil {
+ pkt, stopped, err := d.readMMappedPacket()
+ if err != nil || stopped {
return false, err
}
var (
diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go
index 58d5dfeef..5b786169a 100644
--- a/pkg/tcpip/link/fdbased/mmap_unsafe.go
+++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go
@@ -47,9 +47,14 @@ func (t tPacketHdr) setTPStatus(status uint32) {
}
func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
d := &packetMMapDispatcher{
- fd: fd,
- e: e,
+ stopFd: stopFd,
+ fd: fd,
+ e: e,
}
pageSize := unix.Getpagesize()
if tpBlockSize%pageSize != 0 {
diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go
index ab2855a63..fab34c5fa 100644
--- a/pkg/tcpip/link/fdbased/packet_dispatchers.go
+++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go
@@ -18,6 +18,8 @@
package fdbased
import (
+ "fmt"
+
"golang.org/x/sys/unix"
"gvisor.dev/gvisor/pkg/tcpip"
"gvisor.dev/gvisor/pkg/tcpip/buffer"
@@ -114,9 +116,36 @@ func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView {
return buffer.NewVectorisedView(n, views)
}
+// stopFd is an eventfd used to signal the stop of a dispatcher.
+type stopFd struct {
+ efd int
+}
+
+func newStopFd() (stopFd, error) {
+ efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
+ if err != nil {
+ return stopFd{efd: -1}, fmt.Errorf("failed to create eventfd: %w", err)
+ }
+ return stopFd{efd: efd}, nil
+}
+
+// stop writes to the eventfd and notifies the dispatcher to stop. It does not
+// block.
+func (s *stopFd) stop() {
+ increment := []byte{1, 0, 0, 0, 0, 0, 0, 0}
+ if n, err := unix.Write(s.efd, increment); n != len(increment) || err != nil {
+ // There are two possible errors documented in eventfd(2) for writing:
+ // 1. We are writing 8 bytes and not 0xffffffffffffffff, thus no EINVAL.
+ // 2. stop is only supposed to be called once, so the counter can't reach
+ // the limit, thus no EAGAIN.
+ panic(fmt.Sprintf("write(efd) = (%d, %s), want (%d, nil)", n, err, len(increment)))
+ }
+}
+
// readVDispatcher uses readv() system call to read inbound packets and
// dispatches them.
type readVDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -128,7 +157,15 @@ type readVDispatcher struct {
}
func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
- d := &readVDispatcher{fd: fd, e: e}
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
+ d := &readVDispatcher{
+ stopFd: stopFd,
+ fd: fd,
+ e: e,
+ }
skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported
d.buf = newIovecBuffer(BufConfig, skipsVnetHdr)
return d, nil
@@ -136,8 +173,8 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
// dispatch reads one packet from the file descriptor and dispatches it.
func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
- n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs())
- if n == 0 || err != nil {
+ n, err := rawfile.BlockingReadvUntilStopped(d.efd, d.fd, d.buf.nextIovecs())
+ if n <= 0 || err != nil {
return false, err
}
@@ -184,6 +221,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) {
// recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and
// dispatches them.
type recvMMsgDispatcher struct {
+ stopFd
// fd is the file descriptor used to send and receive packets.
fd int
@@ -207,7 +245,12 @@ const (
)
func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) {
+ stopFd, err := newStopFd()
+ if err != nil {
+ return nil, err
+ }
d := &recvMMsgDispatcher{
+ stopFd: stopFd,
fd: fd,
e: e,
bufs: make([]*iovecBuffer, MaxMsgsPerRecv),
@@ -235,8 +278,8 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) {
d.msgHdrs[k].Msg.SetIovlen(iovLen)
}
- nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs)
- if err != nil {
+ nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.efd, d.fd, d.msgHdrs)
+ if nMsgs == -1 || err != nil {
return false, err
}
// Process each of received packets.
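
The stopFd idea in isolation: an eventfd is a kernel counter whose readability can be polled alongside the packet FD, so a single 8-byte write wakes a blocked poller. A self-contained sketch of the pattern outside gVisor:

    package main

    import (
        "fmt"

        "golang.org/x/sys/unix"
    )

    func main() {
        efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK)
        if err != nil {
            panic(err)
        }
        done := make(chan struct{})

        go func() {
            defer close(done)
            fds := []unix.PollFd{{Fd: int32(efd), Events: unix.POLLIN}}
            for {
                if _, err := unix.Poll(fds, -1); err == unix.EINTR {
                    continue
                }
                if fds[0].Revents&unix.POLLIN != 0 {
                    fmt.Println("stop signalled")
                    return
                }
            }
        }()

        // stop(): one 8-byte write makes the eventfd readable, waking
        // the poller.
        inc := []byte{1, 0, 0, 0, 0, 0, 0, 0}
        if _, err := unix.Write(efd, inc); err != nil {
            panic(err)
        }
        <-done
    }
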
diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go
index b1a28491d..40bd5560b 100644
--- a/pkg/tcpip/link/qdisc/fifo/endpoint.go
+++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go
@@ -115,6 +115,13 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc
// Attach implements stack.LinkEndpoint.Attach.
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
+ // nil means the NIC is being removed.
+ if dispatcher == nil {
+ e.lower.Attach(nil)
+ e.Wait()
+ e.dispatcher = nil
+ return
+ }
e.dispatcher = dispatcher
e.lower.Attach(e)
}
diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
index da900c24b..0b7b9e3de 100644
--- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
+++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go
@@ -12,12 +12,13 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-//go:build ((linux && amd64) || (linux && arm64)) && go1.12 && !go1.18
+//go:build ((linux && amd64) || (linux && arm64)) && go1.12
// +build linux,amd64 linux,arm64
// +build go1.12
-// +build !go1.18
-// Check go:linkname function signatures when updating Go version.
+// //go:linkname directives are type-checked by checklinkname. Any other
+// non-linkname assumptions outside the Go 1 compatibility guarantee should
+// have an accompanying vet check or version guard build tag.
package rawfile
diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
index 53448a641..e76fc55b6 100644
--- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go
+++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go
@@ -170,46 +170,63 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) {
}
}
-// BlockingReadv reads from a file descriptor that is set up as non-blocking and
-// stores the data in a list of iovecs buffers. If no data is available, it will
-// block in a poll() syscall until the file descriptor becomes readable.
-func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
+// BlockingReadvUntilStopped reads from a file descriptor that is set up as
+// non-blocking and stores the data in a list of iovecs buffers. If no data is
+// available, it will block in a poll() syscall until the file descriptor
+// becomes readable or stop is signalled (efd becomes readable). Returns -1 in
+// the latter case.
+func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip.Error) {
for {
n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs)))
if e == 0 {
return int(n), nil
}
- event := PollEvent{
- FD: int32(fd),
- Events: 1, // POLLIN
+ stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
+ if stopped {
+ return -1, nil
}
-
- _, e = BlockingPoll(&event, 1, nil)
if e != 0 && e != unix.EINTR {
return 0, TranslateErrno(e)
}
}
}
-// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking
-// and stores the received messages in a slice of MMsgHdr structures. If no data
-// is available, it will block in a poll() syscall until the file descriptor
-// becomes readable.
-func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
+// BlockingRecvMMsgUntilStopped reads from a file descriptor that is set up as
+// non-blocking and stores the received messages in a slice of MMsgHdr
+// structures. If no data is available, it will block in a poll() syscall until
+// the file descriptor becomes readable or stop is signalled (efd becomes
+// readable). Returns -1 in the latter case.
+func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) {
for {
n, _, e := unix.RawSyscall6(unix.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0)
if e == 0 {
return int(n), nil
}
- event := PollEvent{
- FD: int32(fd),
- Events: 1, // POLLIN
+ stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN)
+ if stopped {
+ return -1, nil
}
-
- if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != unix.EINTR {
+ if e != 0 && e != unix.EINTR {
return 0, TranslateErrno(e)
}
}
}
+
+// BlockingPollUntilStopped polls for events on fd or until a stop is signalled
+// on the event fd efd. Returns true if stopped, i.e., efd has event POLLIN.
+func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) {
+ pevents := [...]PollEvent{
+ {
+ FD: int32(efd),
+ Events: unix.POLLIN,
+ },
+ {
+ FD: int32(fd),
+ Events: events,
+ },
+ }
+ _, errno := BlockingPoll(&pevents[0], len(pevents), nil)
+ return pevents[0].Revents&unix.POLLIN != 0, errno
+}
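
Putting the new rawfile helpers together, a condensed sketch of a dispatcher loop that treats -1 as the stop signal (packet hand-off and iovec management elided):

    package main

    import (
        "golang.org/x/sys/unix"

        "gvisor.dev/gvisor/pkg/tcpip"
        "gvisor.dev/gvisor/pkg/tcpip/link/rawfile"
    )

    // dispatchLoop drains fd until stop is signalled on efd.
    func dispatchLoop(efd, fd int, iovecs []unix.Iovec) tcpip.Error {
        for {
            n, err := rawfile.BlockingReadvUntilStopped(efd, fd, iovecs)
            if err != nil {
                return err
            }
            if n <= 0 {
                return nil // -1: stop signalled; 0: EOF
            }
            _ = n // ... dispatch the n bytes just read ...
        }
    }
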
diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go
index 3bb864ed2..d3edede63 100644
--- a/pkg/tcpip/link/sniffer/pcap.go
+++ b/pkg/tcpip/link/sniffer/pcap.go
@@ -14,7 +14,14 @@
package sniffer
-import "time"
+import (
+ "encoding"
+ "encoding/binary"
+ "time"
+
+ "gvisor.dev/gvisor/pkg/tcpip"
+ "gvisor.dev/gvisor/pkg/tcpip/stack"
+)
type pcapHeader struct {
// MagicNumber is the file magic number.
@@ -39,25 +46,38 @@ type pcapHeader struct {
Network uint32
}
-type pcapPacketHeader struct {
- // Seconds is the timestamp seconds.
- Seconds uint32
-
- // Microseconds is the timestamp microseconds.
- Microseconds uint32
+var _ encoding.BinaryMarshaler = (*pcapPacket)(nil)
- // IncludedLength is the number of octets of packet saved in file.
- IncludedLength uint32
-
- // OriginalLength is the actual length of packet.
- OriginalLength uint32
+type pcapPacket struct {
+ timestamp time.Time
+ packet *stack.PacketBuffer
+ maxCaptureLen int
}
-func newPCAPPacketHeader(now time.Time, incLen, orgLen uint32) pcapPacketHeader {
- return pcapPacketHeader{
- Seconds: uint32(now.Unix()),
- Microseconds: uint32(now.Nanosecond() / 1000),
- IncludedLength: incLen,
- OriginalLength: orgLen,
+func (p *pcapPacket) MarshalBinary() ([]byte, error) {
+ packetSize := p.packet.Size()
+ captureLen := p.maxCaptureLen
+ if packetSize < captureLen {
+ captureLen = packetSize
+ }
+ b := make([]byte, 16+captureLen)
+ binary.BigEndian.PutUint32(b[0:4], uint32(p.timestamp.Unix()))
+ binary.BigEndian.PutUint32(b[4:8], uint32(p.timestamp.Nanosecond()/1000))
+ binary.BigEndian.PutUint32(b[8:12], uint32(captureLen))
+ binary.BigEndian.PutUint32(b[12:16], uint32(packetSize))
+ w := tcpip.SliceWriter(b[16:])
+ for _, v := range p.packet.Views() {
+ if captureLen == 0 {
+ break
+ }
+ if len(v) > captureLen {
+ v = v[:captureLen]
+ }
+ n, err := w.Write(v)
+ if err != nil {
+ panic(err)
+ }
+ captureLen -= n
}
+ return b, nil
}
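
MarshalBinary above emits the classic 16-byte pcap record header followed by at most maxCaptureLen bytes of payload. A standalone sketch of the same record layout for a raw frame (big-endian, matching what this package writes; writeRecord is an invented helper):

    package main

    import (
        "encoding/binary"
        "os"
        "time"
    )

    // writeRecord appends one pcap packet record: ts_sec, ts_usec,
    // incl_len, orig_len, then at most snapLen bytes of the frame.
    func writeRecord(f *os.File, frame []byte, snapLen int) error {
        captured := len(frame)
        if captured > snapLen {
            captured = snapLen
        }
        now := time.Now()
        hdr := make([]byte, 16)
        binary.BigEndian.PutUint32(hdr[0:4], uint32(now.Unix()))
        binary.BigEndian.PutUint32(hdr[4:8], uint32(now.Nanosecond()/1000))
        binary.BigEndian.PutUint32(hdr[8:12], uint32(captured))
        binary.BigEndian.PutUint32(hdr[12:16], uint32(len(frame)))
        if _, err := f.Write(hdr); err != nil {
            return err
        }
        _, err := f.Write(frame[:captured])
        return err
    }
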
diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go
index 3df826f3c..28a172e71 100644
--- a/pkg/tcpip/link/sniffer/sniffer.go
+++ b/pkg/tcpip/link/sniffer/sniffer.go
@@ -151,33 +151,16 @@ func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumbe
logPacket(e.logPrefix, dir, protocol, pkt)
}
if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 {
- totalLength := pkt.Size()
- length := totalLength
- if max := int(e.maxPCAPLen); length > max {
- length = max
+ packet := pcapPacket{
+ timestamp: time.Now(),
+ packet: pkt,
+ maxCaptureLen: int(e.maxPCAPLen),
}
- packetHeader := newPCAPPacketHeader(time.Now(), uint32(length), uint32(totalLength))
- packet := make([]byte, binary.Size(packetHeader)+length)
- {
- writer := tcpip.SliceWriter(packet)
- if err := binary.Write(&writer, binary.BigEndian, packetHeader); err != nil {
- panic(err)
- }
- for _, b := range pkt.Views() {
- if length == 0 {
- break
- }
- if len(b) > length {
- b = b[:length]
- }
- n, err := writer.Write(b)
- if err != nil {
- panic(err)
- }
- length -= n
- }
+ b, err := packet.MarshalBinary()
+ if err != nil {
+ panic(err)
}
- if _, err := writer.Write(packet); err != nil {
+ if _, err := writer.Write(b); err != nil {
panic(err)
}
}
diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD
index c90974693..2257f728e 100644
--- a/pkg/tcpip/network/ipv4/BUILD
+++ b/pkg/tcpip/network/ipv4/BUILD
@@ -39,7 +39,6 @@ go_test(
"//pkg/tcpip/faketime",
"//pkg/tcpip/header",
"//pkg/tcpip/link/channel",
- "//pkg/tcpip/link/loopback",
"//pkg/tcpip/link/sniffer",
"//pkg/tcpip/network/arp",
"//pkg/tcpip/network/internal/testutil",
diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go
index 4a4448cf9..73407be67 100644
--- a/pkg/tcpip/network/ipv4/ipv4_test.go
+++ b/pkg/tcpip/network/ipv4/ipv4_test.go
@@ -32,7 +32,6 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/faketime"
"gvisor.dev/gvisor/pkg/tcpip/header"
"gvisor.dev/gvisor/pkg/tcpip/link/channel"
- "gvisor.dev/gvisor/pkg/tcpip/link/loopback"
"gvisor.dev/gvisor/pkg/tcpip/link/sniffer"
"gvisor.dev/gvisor/pkg/tcpip/network/arp"
iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil"
@@ -3339,7 +3338,7 @@ func TestCloseLocking(t *testing.T) {
defer wg.Done()
for i := 0; i < iterations; i++ {
- if err := s.CreateNIC(nicID2, loopback.New()); err != nil {
+ if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil {
t.Errorf("CreateNIC(%d, _): %s", nicID2, err)
return
}
diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go
index 81fabe29a..c73890c4c 100644
--- a/pkg/tcpip/stack/stack.go
+++ b/pkg/tcpip/stack/stack.go
@@ -780,6 +780,9 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error {
if !ok {
return &tcpip.ErrUnknownNICID{}
}
+ if nic.IsLoopback() {
+ return &tcpip.ErrNotSupported{}
+ }
delete(s.nics, id)
// Remove routes in-place. n tracks the number of routes written.
diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go
index 21951d05a..3089c0ef4 100644
--- a/pkg/tcpip/stack/stack_test.go
+++ b/pkg/tcpip/stack/stack_test.go
@@ -719,38 +719,59 @@ func TestRemoveUnknownNIC(t *testing.T) {
}
func TestRemoveNIC(t *testing.T) {
- const nicID = 1
+ for _, tt := range []struct {
+ name string
+ linkep stack.LinkEndpoint
+ expectErr tcpip.Error
+ }{
+ {
+ name: "loopback",
+ linkep: loopback.New(),
+ expectErr: &tcpip.ErrNotSupported{},
+ },
+ {
+ name: "channel",
+ linkep: channel.New(0, defaultMTU, ""),
+ expectErr: nil,
+ },
+ } {
+ t.Run(tt.name, func(t *testing.T) {
+ const nicID = 1
- s := stack.New(stack.Options{
- NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
- })
+ s := stack.New(stack.Options{
+ NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory},
+ })
- e := linkEPWithMockedAttach{
- LinkEndpoint: loopback.New(),
- }
- if err := s.CreateNIC(nicID, &e); err != nil {
- t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
- }
+ e := linkEPWithMockedAttach{
+ LinkEndpoint: tt.linkep,
+ }
+ if err := s.CreateNIC(nicID, &e); err != nil {
+ t.Fatalf("CreateNIC(%d, _) = %s", nicID, err)
+ }
- // NIC should be present in NICInfo and attached to a NetworkDispatcher.
- allNICInfo := s.NICInfo()
- if _, ok := allNICInfo[nicID]; !ok {
- t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
- }
- if !e.isAttached() {
- t.Fatal("link endpoint not attached to a network dispatcher")
- }
+ // NIC should be present in NICInfo and attached to a NetworkDispatcher.
+ allNICInfo := s.NICInfo()
+ if _, ok := allNICInfo[nicID]; !ok {
+ t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo)
+ }
+ if !e.isAttached() {
+ t.Fatal("link endpoint not attached to a network dispatcher")
+ }
- // Removing a NIC should remove it from NICInfo and e should be detached from
- // the NetworkDispatcher.
- if err := s.RemoveNIC(nicID); err != nil {
- t.Fatalf("s.RemoveNIC(%d): %s", nicID, err)
- }
- if nicInfo, ok := s.NICInfo()[nicID]; ok {
- t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
- }
- if e.isAttached() {
- t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ // Removing a NIC should remove it from NICInfo and e should be detached from
+ // the NetworkDispatcher.
+ if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want {
+ t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want)
+ }
+ if tt.expectErr == nil {
+ if nicInfo, ok := s.NICInfo()[nicID]; ok {
+ t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo)
+ }
+ if e.isAttached() {
+ t.Error("link endpoint for removed NIC still attached to a network dispatcher")
+ }
+ }
+ })
}
}