diff options
Diffstat (limited to 'pkg')
42 files changed, 1534 insertions, 314 deletions
diff --git a/pkg/bitmap/BUILD b/pkg/bitmap/BUILD new file mode 100644 index 000000000..0f1e7006d --- /dev/null +++ b/pkg/bitmap/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "bitmap", + srcs = ["bitmap.go"], + visibility = ["//:sandbox"], +) + +go_test( + name = "bitmap_test", + size = "small", + srcs = ["bitmap_test.go"], + library = ":bitmap", +) diff --git a/pkg/bitmap/bitmap.go b/pkg/bitmap/bitmap.go new file mode 100644 index 000000000..12d2fc2b8 --- /dev/null +++ b/pkg/bitmap/bitmap.go @@ -0,0 +1,234 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package bitmap provides the implementation of bitmap. +package bitmap + +import ( + "math" + "math/bits" +) + +// Bitmap implements an efficient bitmap. +// +// +stateify savable +type Bitmap struct { + // numOnes is the number of ones in the bitmap. + numOnes uint32 + + // bitBlock holds the bits. The type of bitBlock is uint64 which means + // each number in bitBlock contains 64 entries. + bitBlock []uint64 +} + +// New create a new empty Bitmap. +func New(size uint32) Bitmap { + b := Bitmap{} + bSize := (size + 63) / 64 + b.bitBlock = make([]uint64, bSize) + return b +} + +// IsEmpty verifies whether the Bitmap is empty. +func (b *Bitmap) IsEmpty() bool { + return b.numOnes == 0 +} + +// Minimum return the smallest value in the Bitmap. +func (b *Bitmap) Minimum() uint32 { + for i := 0; i < len(b.bitBlock); i++ { + if w := b.bitBlock[i]; w != 0 { + r := bits.TrailingZeros64(w) + return uint32(r + i*64) + } + } + return math.MaxInt32 +} + +// FirstZero returns the first unset bit from the range [start, ). +func (b *Bitmap) FirstZero(start uint32) uint32 { + i, nbit := int(start/64), start%64 + n := len(b.bitBlock) + if i >= n { + return math.MaxInt32 + } + w := b.bitBlock[i] | ((1 << nbit) - 1) + for { + if w != ^uint64(0) { + r := bits.TrailingZeros64(^w) + return uint32(r + i*64) + } + i++ + if i == n { + break + } + w = b.bitBlock[i] + } + return math.MaxInt32 +} + +// Maximum return the largest value in the Bitmap. +func (b *Bitmap) Maximum() uint32 { + for i := len(b.bitBlock) - 1; i >= 0; i-- { + if w := b.bitBlock[i]; w != 0 { + r := bits.LeadingZeros64(w) + return uint32(i*64 + 63 - r) + } + } + return uint32(0) +} + +// Add add i to the Bitmap. +func (b *Bitmap) Add(i uint32) { + blockNum, mask := i/64, uint64(1)<<(i%64) + // if blockNum is out of range, extend b.bitBlock + if x, y := int(blockNum), len(b.bitBlock); x >= y { + b.bitBlock = append(b.bitBlock, make([]uint64, x-y+1)...) + } + oldBlock := b.bitBlock[blockNum] + newBlock := oldBlock | mask + if oldBlock != newBlock { + b.bitBlock[blockNum] = newBlock + b.numOnes++ + } +} + +// Remove i from the Bitmap. +func (b *Bitmap) Remove(i uint32) { + blockNum, mask := i/64, uint64(1)<<(i%64) + oldBlock := b.bitBlock[blockNum] + newBlock := oldBlock &^ mask + if oldBlock != newBlock { + b.bitBlock[blockNum] = newBlock + b.numOnes-- + } +} + +// Clone the Bitmap. +func (b *Bitmap) Clone() Bitmap { + bitmap := Bitmap{b.numOnes, make([]uint64, len(b.bitBlock))} + copy(bitmap.bitBlock, b.bitBlock[:]) + return bitmap +} + +// countOnesForBlocks count all 1 bits within b.bitBlock of begin and that of end. +// The begin block and end block are inclusive. +func (b *Bitmap) countOnesForBlocks(begin, end uint32) uint64 { + ones := uint64(0) + beginBlock := begin / 64 + endBlock := end / 64 + for i := beginBlock; i <= endBlock; i++ { + ones += uint64(bits.OnesCount64(b.bitBlock[i])) + } + return ones +} + +// countOnesForAllBlocks count all 1 bits in b.bitBlock. +func (b *Bitmap) countOnesForAllBlocks() uint64 { + ones := uint64(0) + for i := 0; i < len(b.bitBlock); i++ { + ones += uint64(bits.OnesCount64(b.bitBlock[i])) + } + return ones +} + +// flipRange flip the bits within range (begin and end). begin is inclusive and end is exclusive. +func (b *Bitmap) flipRange(begin, end uint32) { + end-- + beginBlock := begin / 64 + endBlock := end / 64 + if beginBlock == endBlock { + b.bitBlock[endBlock] ^= ((^uint64(0) << uint(begin%64)) & ((uint64(1) << (uint(end)%64 + 1)) - 1)) + } else { + b.bitBlock[beginBlock] ^= ^(^uint64(0) << uint(begin%64)) + for i := beginBlock; i < endBlock; i++ { + b.bitBlock[i] = ^b.bitBlock[i] + } + b.bitBlock[endBlock] ^= ((uint64(1) << (uint(end)%64 + 1)) - 1) + } +} + +// clearRange clear the bits within range (begin and end). begin is inclusive and end is exclusive. +func (b *Bitmap) clearRange(begin, end uint32) { + end-- + beginBlock := begin / 64 + endBlock := end / 64 + if beginBlock == endBlock { + b.bitBlock[beginBlock] &= (((uint64(1) << uint(begin%64)) - 1) | ^((uint64(1) << (uint(end)%64 + 1)) - 1)) + } else { + b.bitBlock[beginBlock] &= ((uint64(1) << uint(begin%64)) - 1) + for i := beginBlock + 1; i < endBlock; i++ { + b.bitBlock[i] &= ^b.bitBlock[i] + } + b.bitBlock[endBlock] &= ^((uint64(1) << (uint(end)%64 + 1)) - 1) + } +} + +// ClearRange clear bits within range (begin and end) for the Bitmap. begin is inclusive and end is exclusive. +func (b *Bitmap) ClearRange(begin, end uint32) { + blockRange := end/64 - begin/64 + // When the number of cleared blocks is larger than half of the length of b.bitBlock, + // counting 1s for the entire bitmap has better performance. + if blockRange > uint32(len(b.bitBlock)/2) { + b.clearRange(begin, end) + b.numOnes = uint32(b.countOnesForAllBlocks()) + } else { + oldRangeOnes := b.countOnesForBlocks(begin, end) + b.clearRange(begin, end) + newRangeOnes := b.countOnesForBlocks(begin, end) + b.numOnes += uint32(newRangeOnes - oldRangeOnes) + } +} + +// FlipRange flip bits within range (begin and end) for the Bitmap. begin is inclusive and end is exclusive. +func (b *Bitmap) FlipRange(begin, end uint32) { + blockRange := end/64 - begin/64 + // When the number of flipped blocks is larger than half of the length of b.bitBlock, + // counting 1s for the entire bitmap has better performance. + if blockRange > uint32(len(b.bitBlock)/2) { + b.flipRange(begin, end) + b.numOnes = uint32(b.countOnesForAllBlocks()) + } else { + oldRangeOnes := b.countOnesForBlocks(begin, end) + b.flipRange(begin, end) + newRangeOnes := b.countOnesForBlocks(begin, end) + b.numOnes += uint32(newRangeOnes - oldRangeOnes) + } +} + +// ToSlice transform the Bitmap into slice. For example, a bitmap of [0, 1, 0, 1] +// will return the slice [1, 3]. +func (b *Bitmap) ToSlice() []uint32 { + bitmapSlice := make([]uint32, 0, b.numOnes) + // base is the start number of a bitBlock + base := 0 + for i := 0; i < len(b.bitBlock); i++ { + bitBlock := b.bitBlock[i] + // Iterate through all the numbers held by this bit block. + for bitBlock != 0 { + // Extract the lowest set 1 bit. + j := bitBlock & -bitBlock + // Interpret the bit as the in32 number it represents and add it to result. + bitmapSlice = append(bitmapSlice, uint32((base + int(bits.OnesCount64(j-1))))) + bitBlock ^= j + } + base += 64 + } + return bitmapSlice +} + +// GetNumOnes return the the number of ones in the Bitmap. +func (b *Bitmap) GetNumOnes() uint32 { + return b.numOnes +} diff --git a/pkg/bitmap/bitmap_test.go b/pkg/bitmap/bitmap_test.go new file mode 100644 index 000000000..76ebd779f --- /dev/null +++ b/pkg/bitmap/bitmap_test.go @@ -0,0 +1,306 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bitmap + +import ( + "math" + "reflect" + "testing" +) + +// generateFilledSlice generates a slice fill with numbers. +func generateFilledSlice(min, max, length int) []uint32 { + if max == min { + return []uint32{uint32(min)} + } + if length > (max - min) { + return nil + } + randSlice := make([]uint32, length) + if length != 0 { + rangeNum := uint32((max - min) / length) + randSlice[0], randSlice[length-1] = uint32(min), uint32(max) + for i := 1; i < length-1; i++ { + randSlice[i] = randSlice[i-1] + rangeNum + } + } + return randSlice +} + +// generateFilledBitmap generates a Bitmap filled with fillNum of numbers, +// and returns the slice and bitmap. +func generateFilledBitmap(min, max, fillNum int) ([]uint32, Bitmap) { + bitmap := New(uint32(max)) + randSlice := generateFilledSlice(min, max, fillNum) + for i := 0; i < fillNum; i++ { + bitmap.Add(randSlice[i]) + } + return randSlice, bitmap +} + +func TestNewBitmap(t *testing.T) { + tests := []struct { + name string + size int + expectSize int + }{ + {"length 1", 1, 1}, + {"length 1024", 1024, 16}, + {"length 1025", 1025, 17}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + if bitmap := New(uint32(tt.size)); len(bitmap.bitBlock) != tt.expectSize { + t.Errorf("New created bitmap with %v, bitBlock size: %d, wanted: %d", tt.name, len(bitmap.bitBlock), tt.expectSize) + } + }) + } +} + +func TestAdd(t *testing.T) { + tests := []struct { + name string + bitmapSize int + addNum int + }{ + {"Add with null bitmap.bitBlock", 0, 10}, + {"Add without extending bitBlock", 64, 10}, + {"Add without extending bitblock with margin number", 63, 64}, + {"Add with extended one block", 1024, 1025}, + {"Add with extended more then one block", 1024, 2048}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + bitmap := New(uint32(tt.bitmapSize)) + bitmap.Add(uint32(tt.addNum)) + bitmapSlice := bitmap.ToSlice() + if bitmapSlice[0] != uint32(tt.addNum) { + t.Errorf("%v, get number: %d, wanted: %d.", tt.name, bitmapSlice[0], tt.addNum) + } + }) + } +} + +func TestRemove(t *testing.T) { + bitmap := New(uint32(1024)) + firstSlice := generateFilledSlice(0, 511, 50) + secondSlice := generateFilledSlice(512, 1024, 50) + for i := 0; i < 50; i++ { + bitmap.Add(firstSlice[i]) + bitmap.Add(secondSlice[i]) + } + + for i := 0; i < 50; i++ { + bitmap.Remove(firstSlice[i]) + } + bitmapSlice := bitmap.ToSlice() + if !reflect.DeepEqual(bitmapSlice, secondSlice) { + t.Errorf("After Remove() firstSlice, remained slice: %v, wanted: %v", bitmapSlice, secondSlice) + } + + for i := 0; i < 50; i++ { + bitmap.Remove(secondSlice[i]) + } + bitmapSlice = bitmap.ToSlice() + emptySlice := make([]uint32, 0) + if !reflect.DeepEqual(bitmapSlice, emptySlice) { + t.Errorf("After Remove secondSlice, remained slice: %v, wanted: %v", bitmapSlice, emptySlice) + } + +} + +// Verify flip bits within one bitBlock, one bit and bits cross multi bitBlocks. +func TestFlipRange(t *testing.T) { + tests := []struct { + name string + flipRangeMin int + flipRangeMax int + filledSliceLen int + }{ + {"Flip one number in bitmap", 77, 77, 1}, + {"Flip numbers within one bitBlocks", 5, 60, 20}, + {"Flip numbers that cross multi bitBlocks", 20, 1000, 300}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + fillSlice, bitmap := generateFilledBitmap(tt.flipRangeMin, tt.flipRangeMax, tt.filledSliceLen) + flipFillSlice := make([]uint32, 0) + for i, j := tt.flipRangeMin, 0; i <= tt.flipRangeMax; i++ { + if uint32(i) != fillSlice[j] { + flipFillSlice = append(flipFillSlice, uint32(i)) + } else { + j++ + } + } + + bitmap.FlipRange(uint32(tt.flipRangeMin), uint32(tt.flipRangeMax+1)) + flipBitmapSlice := bitmap.ToSlice() + if !reflect.DeepEqual(flipFillSlice, flipBitmapSlice) { + t.Errorf("%v, flipped slice: %v, wanted: %v", tt.name, flipBitmapSlice, flipFillSlice) + } + }) + } +} + +// Verify clear bits within one bitBlock, one bit and bits cross multi bitBlocks. +func TestClearRange(t *testing.T) { + tests := []struct { + name string + clearRangeMin int + clearRangeMax int + bitmapSize int + }{ + {"ClearRange clear one number", 5, 5, 64}, + {"ClearRange clear numbers within one bitBlock", 4, 61, 64}, + {"ClearRange clear numbers cross multi bitBlocks", 20, 254, 512}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + bitmap := New(uint32(tt.bitmapSize)) + bitmap.FlipRange(uint32(0), uint32(tt.bitmapSize)) + bitmap.ClearRange(uint32(tt.clearRangeMin), uint32(tt.clearRangeMax+1)) + clearedBitmapSlice := bitmap.ToSlice() + clearedSlice := make([]uint32, 0) + for i := 0; i < tt.bitmapSize; i++ { + if i < tt.clearRangeMin || i > tt.clearRangeMax { + clearedSlice = append(clearedSlice, uint32(i)) + } + } + if !reflect.DeepEqual(clearedSlice, clearedBitmapSlice) { + t.Errorf("%v, cleared slice: %v, wanted: %v", tt.name, clearedBitmapSlice, clearedSlice) + } + }) + + } +} + +func TestMinimum(t *testing.T) { + randSlice, bitmap := generateFilledBitmap(0, 1024, 200) + min := bitmap.Minimum() + if min != randSlice[0] { + t.Errorf("Minimum() returns: %v, wanted: %v", min, randSlice[0]) + } + + bitmap.ClearRange(uint32(0), uint32(200)) + min = bitmap.Minimum() + bitmapSlice := bitmap.ToSlice() + if min != bitmapSlice[0] { + t.Errorf("After ClearRange, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0]) + } + + bitmap.FlipRange(uint32(2), uint32(11)) + min = bitmap.Minimum() + bitmapSlice = bitmap.ToSlice() + if min != bitmapSlice[0] { + t.Errorf("After Flip, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0]) + } +} + +func TestMaximum(t *testing.T) { + randSlice, bitmap := generateFilledBitmap(0, 1024, 200) + max := bitmap.Maximum() + if max != randSlice[len(randSlice)-1] { + t.Errorf("Maximum() returns: %v, wanted: %v", max, randSlice[len(randSlice)-1]) + } + + bitmap.ClearRange(uint32(1000), uint32(1025)) + max = bitmap.Maximum() + bitmapSlice := bitmap.ToSlice() + if max != bitmapSlice[len(bitmapSlice)-1] { + t.Errorf("After ClearRange, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1]) + } + + bitmap.FlipRange(uint32(1001), uint32(1021)) + max = bitmap.Maximum() + bitmapSlice = bitmap.ToSlice() + if max != bitmapSlice[len(bitmapSlice)-1] { + t.Errorf("After Flip, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1]) + } +} + +func TestBitmapNumOnes(t *testing.T) { + randSlice, bitmap := generateFilledBitmap(0, 1024, 200) + bitmapOnes := bitmap.GetNumOnes() + if bitmapOnes != uint32(200) { + t.Errorf("GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) + } + // Remove 10 numbers. + for i := 5; i < 15; i++ { + bitmap.Remove(randSlice[i]) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(190) { + t.Errorf("After Remove 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190) + } + // Remove the 10 number again, the length supposed not change. + for i := 5; i < 15; i++ { + bitmap.Remove(randSlice[i]) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(190) { + t.Errorf("After Remove the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190) + } + + // Add 10 number. + for i := 1080; i < 1090; i++ { + bitmap.Add(uint32(i)) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(200) { + t.Errorf("After Add 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) + } + + // Add the 10 number again, the length supposed not change. + for i := 1080; i < 1090; i++ { + bitmap.Add(uint32(i)) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(200) { + t.Errorf("After Add the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) + } + + // Flip 10 bits from 0 to 1. + bitmap.FlipRange(uint32(1025), uint32(1035)) + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(210) { + t.Errorf("After Flip, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 210) + } + + // ClearRange numbers range from [0, 1025). + bitmap.ClearRange(uint32(0), uint32(1025)) + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(20) { + t.Errorf("After ClearRange, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 20) + } +} + +func TestFirstZero(t *testing.T) { + bitmap := New(uint32(1000)) + bitmap.FlipRange(200, 400) + for i, j := range map[uint32]uint32{0: 0, 201: 400, 200: 400, 199: 199, 400: 400, 10000: math.MaxInt32} { + v := bitmap.FirstZero(i) + if v != j { + t.Errorf("Minimum() returns: %v, wanted: %v", v, j) + } + } +} diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go index 09fc14787..bd8ceba19 100644 --- a/pkg/gohacks/gohacks_unsafe.go +++ b/pkg/gohacks/gohacks_unsafe.go @@ -15,7 +15,13 @@ //go:build go1.13 && !go1.18 // +build go1.13,!go1.18 -// Check type signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. + +// Check type signatures and Noescape when updating Go version. +// +// TODO(b/165820485): add these checks to checklinkname. // Package gohacks contains utilities for subverting the Go compiler. package gohacks diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index b5bbfff90..74a8de42c 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -15,6 +15,10 @@ //go:build amd64 && go1.8 && !go1.18 && go1.1 // +build amd64,go1.8,!go1.18,go1.1 +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. + #include "textflag.h" TEXT ·Current(SB),NOSPLIT,$0-8 diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index 772d96289..48182c4a9 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -15,6 +15,10 @@ //go:build arm64 && go1.8 && !go1.18 && go1.1 // +build arm64,go1.8,!go1.18,go1.1 +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. + #include "textflag.h" TEXT ·Current(SB),NOSPLIT,$0-8 diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index 24e28a51f..22c8b7fda 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -383,11 +383,6 @@ func (d *dir) DecRef(ctx context.Context) { d.dirRefs.DecRef(func() { d.Destroy(ctx) }) } -// StatFS implements kernfs.Inode.StatFS. -func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { - return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil -} - // controllerFile represents a generic control file that appears within a cgroup // directory. // diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index ec8d58cc9..25d2e39d6 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1161,6 +1161,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs if !d.isSynthetic() { if stat.Mask != 0 { + if stat.Mask&linux.STATX_SIZE != 0 { + // d.dataMu must be held around the update to both the remote + // file's size and d.size to serialize with writeback (which + // might otherwise write data back up to the old d.size after + // the remote file has been truncated). + d.dataMu.Lock() + } if err := d.file.setAttr(ctx, p9.SetAttrMask{ Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, @@ -1180,13 +1187,16 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs MTimeSeconds: uint64(stat.Mtime.Sec), MTimeNanoSeconds: uint64(stat.Mtime.Nsec), }); err != nil { + if stat.Mask&linux.STATX_SIZE != 0 { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } return err } if stat.Mask&linux.STATX_SIZE != 0 { // d.size should be kept up to date, and privatized // copy-on-write mappings of truncated pages need to be // invalidated, even if InteropModeShared is in effect. - d.updateSizeLocked(stat.Size) + d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above } } if d.fs.opts.interop == InteropModeShared { @@ -1249,6 +1259,14 @@ func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate // Preconditions: d.metadataMu must be locked. func (d *dentry) updateSizeLocked(newSize uint64) { d.dataMu.Lock() + d.updateSizeAndUnlockDataMuLocked(newSize) +} + +// Preconditions: d.metadataMu and d.dataMu must be locked. +// +// Postconditions: d.dataMu is unlocked. +// +checklocksrelease:d.dataMu +func (d *dentry) updateSizeAndUnlockDataMuLocked(newSize uint64) { oldSize := d.size atomic.StoreUint64(&d.size, newSize) // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings @@ -1257,9 +1275,9 @@ func (d *dentry) updateSizeLocked(newSize uint64) { // contents beyond the new d.size. (We are still holding d.metadataMu, // so we can't race with Write or another truncate.) d.dataMu.Unlock() - if d.size < oldSize { + if newSize < oldSize { oldpgend, _ := hostarch.PageRoundUp(oldSize) - newpgend, _ := hostarch.PageRoundUp(d.size) + newpgend, _ := hostarch.PageRoundUp(newSize) if oldpgend != newpgend { d.mapsMu.Lock() d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ @@ -1275,8 +1293,8 @@ func (d *dentry) updateSizeLocked(newSize uint64) { // truncated pages have been removed from the remote file, they // should be dropped without being written back. d.dataMu.Lock() - d.cache.Truncate(d.size, d.fs.mfp.MemoryFile()) - d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend}) + d.cache.Truncate(newSize, d.fs.mfp.MemoryFile()) + d.dirty.KeepClean(memmap.MappableRange{newSize, oldpgend}) d.dataMu.Unlock() } } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index cbbc0935a..f54811edf 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -78,7 +78,7 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns "smaps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &smapsData{task: task}), "stat": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), "statm": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statmData{task: task}), - "status": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}), + "status": fs.newStatusInode(ctx, task, pidns, fs.NextIno(), 0444), "uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 5bb6bc372..0ce3ed797 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -661,34 +661,119 @@ func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -// statusData implements vfs.DynamicBytesSource for /proc/[pid]/status. +// statusInode implements kernfs.Inode for /proc/[pid]/status. // // +stateify savable -type statusData struct { - kernfs.DynamicBytesFile +type statusInode struct { + kernfs.InodeAttrs + kernfs.InodeNoStatFS + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink task *kernel.Task pidns *kernel.PIDNamespace + locks vfs.FileLocks } -var _ dynamicInode = (*statusData)(nil) +// statusFD implements vfs.FileDescriptionImpl and vfs.DynamicByteSource for +// /proc/[pid]/status. +// +// +stateify savable +type statusFD struct { + statusFDLowerBase + vfs.DynamicBytesFileDescriptionImpl + vfs.LockFD + + vfsfd vfs.FileDescription + + inode *statusInode + task *kernel.Task + pidns *kernel.PIDNamespace + userns *auth.UserNamespace // equivalent to struct file::f_cred::user_ns +} + +// statusFDLowerBase is a dumb hack to ensure that statusFD prefers +// vfs.DynamicBytesFileDescriptionImpl methods to vfs.FileDescriptinDefaultImpl +// methods. +// +// +stateify savable +type statusFDLowerBase struct { + vfs.FileDescriptionDefaultImpl +} + +func (fs *filesystem) newStatusInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, ino uint64, perm linux.FileMode) kernfs.Inode { + // Note: credentials are overridden by taskOwnedInode. + inode := &statusInode{ + task: task, + pidns: pidns, + } + inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeRegular|perm) + return &taskOwnedInode{Inode: inode, owner: task} +} + +// Open implements kernfs.Inode.Open. +func (s *statusInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &statusFD{ + inode: s, + task: s.task, + pidns: s.pidns, + userns: rp.Credentials().UserNamespace, + } + fd.LockFD.Init(&s.locks) + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + fd.SetDataSource(fd) + return &fd.vfsfd, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (*statusInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return linuxerr.EPERM +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (s *statusFD) Release(ctx context.Context) { +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (s *statusFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := s.vfsfd.VirtualDentry().Mount().Filesystem() + return s.inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (s *statusFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + return linuxerr.EPERM +} // Generate implements vfs.DynamicBytesSource.Generate. -func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { +func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name()) fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus()) fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup())) fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task)) + ppid := kernel.ThreadID(0) if parent := s.task.Parent(); parent != nil { ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) } fmt.Fprintf(buf, "PPid:\t%d\n", ppid) + tpid := kernel.ThreadID(0) if tracer := s.task.Tracer(); tracer != nil { tpid = s.pidns.IDOfTask(tracer) } fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid) + + creds := s.task.Credentials() + ruid := creds.RealKUID.In(s.userns).OrOverflow() + euid := creds.EffectiveKUID.In(s.userns).OrOverflow() + suid := creds.SavedKUID.In(s.userns).OrOverflow() + rgid := creds.RealKGID.In(s.userns).OrOverflow() + egid := creds.EffectiveKGID.In(s.userns).OrOverflow() + sgid := creds.SavedKGID.In(s.userns).OrOverflow() var fds int var vss, rss, data uint64 s.task.WithMuLocked(func(t *kernel.Task) { @@ -701,12 +786,26 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { data = mm.VirtualDataSize() } }) + // Filesystem user/group IDs aren't implemented; effective UID/GID are used + // instead. + fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid) + fmt.Fprintf(buf, "Gid:\t%d\t%d\t%d\t%d\n", rgid, egid, sgid, egid) fmt.Fprintf(buf, "FDSize:\t%d\n", fds) + buf.WriteString("Groups:\t ") + // There is a space between each pair of supplemental GIDs, as well as an + // unconditional trailing space that some applications actually depend on. + var sep string + for _, kgid := range creds.ExtraKGIDs { + fmt.Fprintf(buf, "%s%d", sep, kgid.In(s.userns).OrOverflow()) + sep = " " + } + buf.WriteString(" \n") + fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10) fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10) fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10) + fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count()) - creds := s.task.Credentials() fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps) fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps) fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80dda1559..b121fc1b4 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -27,6 +27,9 @@ type Stack interface { // integers. Interfaces() map[int32]Interface + // RemoveInterface removes the specified network interface. + RemoveInterface(idx int32) error + // InterfaceAddrs returns all network interface addresses as a mapping from // interface indexes to a slice of associated interface address properties. InterfaceAddrs() map[int32][]InterfaceAddr diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 218d9dafc..621f47e1f 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -45,23 +45,29 @@ func NewTestStack() *TestStack { } } -// Interfaces implements Stack.Interfaces. +// Interfaces implements Stack. func (s *TestStack) Interfaces() map[int32]Interface { return s.InterfacesMap } -// InterfaceAddrs implements Stack.InterfaceAddrs. +// RemoveInterface implements Stack. +func (s *TestStack) RemoveInterface(idx int32) error { + delete(s.InterfacesMap, idx) + return nil +} + +// InterfaceAddrs implements Stack. func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr { return s.InterfaceAddrsMap } -// AddInterfaceAddr implements Stack.AddInterfaceAddr. +// AddInterfaceAddr implements Stack. func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr) return nil } -// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +// RemoveInterfaceAddr implements Stack. func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { interfaceAddrs, ok := s.InterfaceAddrsMap[idx] if !ok { @@ -79,94 +85,94 @@ func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } -// SupportsIPv6 implements Stack.SupportsIPv6. +// SupportsIPv6 implements Stack. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag } -// TCPReceiveBufferSize implements Stack.TCPReceiveBufferSize. +// TCPReceiveBufferSize implements Stack. func (s *TestStack) TCPReceiveBufferSize() (TCPBufferSize, error) { return s.TCPRecvBufSize, nil } -// SetTCPReceiveBufferSize implements Stack.SetTCPReceiveBufferSize. +// SetTCPReceiveBufferSize implements Stack. func (s *TestStack) SetTCPReceiveBufferSize(size TCPBufferSize) error { s.TCPRecvBufSize = size return nil } -// TCPSendBufferSize implements Stack.TCPSendBufferSize. +// TCPSendBufferSize implements Stack. func (s *TestStack) TCPSendBufferSize() (TCPBufferSize, error) { return s.TCPSendBufSize, nil } -// SetTCPSendBufferSize implements Stack.SetTCPSendBufferSize. +// SetTCPSendBufferSize implements Stack. func (s *TestStack) SetTCPSendBufferSize(size TCPBufferSize) error { s.TCPSendBufSize = size return nil } -// TCPSACKEnabled implements Stack.TCPSACKEnabled. +// TCPSACKEnabled implements Stack. func (s *TestStack) TCPSACKEnabled() (bool, error) { return s.TCPSACKFlag, nil } -// SetTCPSACKEnabled implements Stack.SetTCPSACKEnabled. +// SetTCPSACKEnabled implements Stack. func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { s.TCPSACKFlag = enabled return nil } -// TCPRecovery implements Stack.TCPRecovery. +// TCPRecovery implements Stack. func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) { return s.Recovery, nil } -// SetTCPRecovery implements Stack.SetTCPRecovery. +// SetTCPRecovery implements Stack. func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error { s.Recovery = recovery return nil } -// Statistics implements inet.Stack.Statistics. +// Statistics implements Stack. func (s *TestStack) Statistics(stat interface{}, arg string) error { return nil } -// RouteTable implements Stack.RouteTable. +// RouteTable implements Stack. func (s *TestStack) RouteTable() []Route { return s.RouteList } -// Resume implements Stack.Resume. +// Resume implements Stack. func (s *TestStack) Resume() {} -// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +// RegisteredEndpoints implements Stack. func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } -// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +// CleanupEndpoints implements Stack. func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { return nil } -// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +// RestoreCleanupEndpoints implements Stack. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} -// SetForwarding implements inet.Stack.SetForwarding. +// SetForwarding implements Stack. func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { s.IPForwarding = enable return nil } -// PortRange implements inet.Stack.PortRange. +// PortRange implements Stack. func (*TestStack) PortRange() (uint16, uint16) { // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). return 32768, 28232 } -// SetPortRange implements inet.Stack.SetPortRange. +// SetPortRange implements Stack. func (*TestStack) SetPortRange(start uint16, end uint16) error { // No-op. return nil diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index c613f4932..e4e0dc04f 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -220,6 +220,7 @@ go_library( "//pkg/abi/linux", "//pkg/abi/linux/errno", "//pkg/amutex", + "//pkg/bitmap", "//pkg/bits", "//pkg/bpf", "//pkg/cleanup", diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 8786a70b5..eff556a0c 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -18,10 +18,10 @@ import ( "fmt" "math" "strings" - "sync/atomic" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bitmap" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -84,13 +84,8 @@ type FDTable struct { // mu protects below. mu sync.Mutex `state:"nosave"` - // next is start position to find fd. - next int32 - - // used contains the number of non-nil entries. It must be accessed - // atomically. It may be read atomically without holding mu (but not - // written). - used int32 + // fdBitmap shows which fds are already in use. + fdBitmap bitmap.Bitmap `state:"nosave"` // descriptorTable holds descriptors. descriptorTable `state:".(map[int32]descriptor)"` @@ -98,6 +93,8 @@ type FDTable struct { func (f *FDTable) saveDescriptorTable() map[int32]descriptor { m := make(map[int32]descriptor) + f.mu.Lock() + defer f.mu.Unlock() f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { m[fd] = descriptor{ file: file, @@ -111,12 +108,16 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() f.initNoLeakCheck() // Initialize table. - f.used = 0 + f.fdBitmap = bitmap.New(uint32(math.MaxUint16)) for fd, d := range m { + if fd < 0 { + panic(fmt.Sprintf("FD is not supposed to be negative. FD: %d", fd)) + } + if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { panic("VFS1 or VFS2 files set") } - + f.fdBitmap.Add(uint32(fd)) // Note that we do _not_ need to acquire a extra table reference here. The // table reference will already be accounted for in the file, so we drop the // reference taken by set above. @@ -189,8 +190,10 @@ func (f *FDTable) DecRef(ctx context.Context) { func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) { // retries tracks the number of failed TryIncRef attempts for the same FD. retries := 0 - fd := int32(0) - for { + fds := f.fdBitmap.ToSlice() + // Iterate through the fdBitmap. + for _, ufd := range fds { + fd := int32(ufd) file, fileVFS2, flags, ok := f.getAll(fd) if !ok { break @@ -218,7 +221,6 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2.DecRef(ctx) } retries = 0 - fd++ } } @@ -226,6 +228,8 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, func (f *FDTable) String() string { var buf strings.Builder ctx := context.Background() + f.mu.Lock() + defer f.mu.Unlock() f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { switch { case file != nil: @@ -250,10 +254,10 @@ func (f *FDTable) String() string { } // NewFDs allocates new FDs guaranteed to be the lowest number available -// greater than or equal to the fd parameter. All files will share the set +// greater than or equal to the minFD parameter. All files will share the set // flags. Success is guaranteed to be all or none. -func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags FDFlags) (fds []int32, err error) { - if fd < 0 { +func (f *FDTable) NewFDs(ctx context.Context, minFD int32, files []*fs.File, flags FDFlags) (fds []int32, err error) { + if minFD < 0 { // Don't accept negative FDs. return nil, unix.EINVAL } @@ -267,31 +271,48 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags if lim.Cur != limits.Infinity { end = int32(lim.Cur) } - if fd >= end { + if minFD+int32(len(files)) > end { return nil, unix.EMFILE } } f.mu.Lock() - // From f.next to find available fd. - if fd < f.next { - fd = f.next + // max is used as the largest number in fdBitmap + 1. + max := int32(0) + + if !f.fdBitmap.IsEmpty() { + max = int32(f.fdBitmap.Maximum()) + max++ } + // Adjust max in case it is less than minFD. + if max < minFD { + max = minFD + } // Install all entries. - for i := fd; i < end && len(fds) < len(files); i++ { - if d, _, _ := f.get(i); d == nil { - // Set the descriptor. - f.set(ctx, i, files[len(fds)], flags) - fds = append(fds, i) // Record the file descriptor. + for len(fds) < len(files) { + // Try to use free bit in fdBitmap. + // If all bits in fdBitmap are used, expand fd to the max. + fd := f.fdBitmap.FirstZero(uint32(minFD)) + if fd == math.MaxInt32 { + fd = uint32(max) + max++ + } + if fd >= uint32(end) { + break } + f.fdBitmap.Add(fd) + f.set(ctx, int32(fd), files[len(fds)], flags) + fds = append(fds, int32(fd)) + minFD = int32(fd) } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { f.set(ctx, i, nil, FDFlags{}) + f.fdBitmap.Remove(uint32(i)) } f.mu.Unlock() @@ -305,20 +326,15 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return nil, unix.EMFILE } - if fd == f.next { - // Update next search start position. - f.next = fds[len(fds)-1] + 1 - } - f.mu.Unlock() return fds, nil } // NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available -// greater than or equal to the fd parameter. All files will share the set +// greater than or equal to the minFD parameter. All files will share the set // flags. Success is guaranteed to be all or none. -func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) { - if fd < 0 { +func (f *FDTable) NewFDsVFS2(ctx context.Context, minFD int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) { + if minFD < 0 { // Don't accept negative FDs. return nil, unix.EINVAL } @@ -332,31 +348,47 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes if lim.Cur != limits.Infinity { end = int32(lim.Cur) } - if fd >= end { + if minFD >= end { return nil, unix.EMFILE } } f.mu.Lock() - // From f.next to find available fd. - if fd < f.next { - fd = f.next + // max is used as the largest number in fdBitmap + 1. + max := int32(0) + + if !f.fdBitmap.IsEmpty() { + max = int32(f.fdBitmap.Maximum()) + max++ } - // Install all entries. - for i := fd; i < end && len(fds) < len(files); i++ { - if d, _, _ := f.getVFS2(i); d == nil { - // Set the descriptor. - f.setVFS2(ctx, i, files[len(fds)], flags) - fds = append(fds, i) // Record the file descriptor. - } + // Adjust max in case it is less than minFD. + if max < minFD { + max = minFD } + for len(fds) < len(files) { + // Try to use free bit in fdBitmap. + // If all bits in fdBitmap are used, expand fd to the max. + fd := f.fdBitmap.FirstZero(uint32(minFD)) + if fd == math.MaxInt32 { + fd = uint32(max) + max++ + } + if fd >= uint32(end) { + break + } + f.fdBitmap.Add(fd) + f.setVFS2(ctx, int32(fd), files[len(fds)], flags) + fds = append(fds, int32(fd)) + minFD = int32(fd) + } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { f.setVFS2(ctx, i, nil, FDFlags{}) + f.fdBitmap.Remove(uint32(i)) } f.mu.Unlock() @@ -370,57 +402,19 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes return nil, unix.EMFILE } - if fd == f.next { - // Update next search start position. - f.next = fds[len(fds)-1] + 1 - } - f.mu.Unlock() return fds, nil } -// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for +// NewFDVFS2 allocates a file descriptor greater than or equal to minFD for // the given file description. If it succeeds, it takes a reference on file. -func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { - if minfd < 0 { - // Don't accept negative FDs. - return -1, unix.EINVAL - } - - // Default limit. - end := int32(math.MaxInt32) - - // Ensure we don't get past the provided limit. - if limitSet := limits.FromContext(ctx); limitSet != nil { - lim := limitSet.Get(limits.NumberOfFiles) - if lim.Cur != limits.Infinity { - end = int32(lim.Cur) - } - if minfd >= end { - return -1, unix.EMFILE - } - } - - f.mu.Lock() - defer f.mu.Unlock() - - // From f.next to find available fd. - fd := minfd - if fd < f.next { - fd = f.next - } - for fd < end { - if d, _, _ := f.getVFS2(fd); d == nil { - f.setVFS2(ctx, fd, file, flags) - if fd == f.next { - // Update next search start position. - f.next = fd + 1 - } - return fd, nil - } - fd++ +func (f *FDTable) NewFDVFS2(ctx context.Context, minFD int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { + files := []*vfs.FileDescription{file} + fileSlice, error := f.NewFDsVFS2(ctx, minFD, files, flags) + if error != nil { + return -1, error } - return -1, unix.EMFILE + return fileSlice[0], nil } // NewFDAt sets the file reference for the given FD. If there is an active @@ -469,6 +463,11 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 defer f.mu.Unlock() df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags) + // Add fd to fdBitmap. + if file != nil || fileVFS2 != nil { + f.fdBitmap.Add(uint32(fd)) + } + return df, dfVFS2, nil } @@ -573,7 +572,9 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) { // Precondition: The caller must be running on the task goroutine, or Task.mu // must be locked. func (f *FDTable) GetFDs(ctx context.Context) []int32 { - fds := make([]int32, 0, int(atomic.LoadInt32(&f.used))) + f.mu.Lock() + defer f.mu.Unlock() + fds := make([]int32, 0, int(f.fdBitmap.GetNumOnes())) f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) { fds = append(fds, fd) }) @@ -583,13 +584,15 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 { // Fork returns an independent FDTable. func (f *FDTable) Fork(ctx context.Context) *FDTable { clone := f.k.NewFDTable() - + f.mu.Lock() + defer f.mu.Unlock() f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { // The set function here will acquire an appropriate table // reference for the clone. We don't need anything else. if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil { panic("VFS1 or VFS2 files set") } + clone.fdBitmap.Add(uint32(fd)) }) return clone } @@ -604,11 +607,6 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc f.mu.Lock() - // Update current available position. - if fd < f.next { - f.next = fd - } - orig, orig2, _, _ := f.getAll(fd) // Add reference for caller. @@ -621,6 +619,7 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc if orig != nil || orig2 != nil { orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry. + f.fdBitmap.Remove(uint32(fd)) } f.mu.Unlock() @@ -644,16 +643,13 @@ func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDes f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { if cond(file, fileVFS2, flags) { df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table. + f.fdBitmap.Remove(uint32(fd)) if df != nil { files = append(files, df) } if dfVFS2 != nil { filesVFS2 = append(filesVFS2, dfVFS2) } - // Update current available position. - if fd < f.next { - f.next = fd - } } }) f.mu.Unlock() diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index f17f9c59c..2b3e6ef71 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -15,9 +15,11 @@ package kernel import ( + "math" "sync/atomic" "unsafe" + "gvisor.dev/gvisor/pkg/bitmap" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -44,6 +46,7 @@ func (f *FDTable) initNoLeakCheck() { func (f *FDTable) init() { f.initNoLeakCheck() f.InitRefs() + f.fdBitmap = bitmap.New(uint32(math.MaxUint16)) } // get gets a file entry. @@ -162,14 +165,6 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 } } - // Adjust used. - switch { - case orig == nil && desc != nil: - atomic.AddInt32(&f.used, 1) - case orig != nil && desc == nil: - atomic.AddInt32(&f.used, -1) - } - if orig != nil { switch { case orig.file != nil: diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go index 3ce926950..c111297d7 100644 --- a/pkg/sentry/kernel/msgqueue/msgqueue.go +++ b/pkg/sentry/kernel/msgqueue/msgqueue.go @@ -119,14 +119,21 @@ type Queue struct { type Message struct { msgEntry - // mType is an integer representing the type of the sent message. - mType int64 + // Type is an integer representing the type of the sent message. + Type int64 - // mText is an untyped block of memory. - mText []byte + // Text is an untyped block of memory. + Text []byte - // mSize is the size of mText. - mSize uint64 + // Size is the size of Text. + Size uint64 +} + +// Blocker is used for blocking Queue.Send, and Queue.Receive calls that serves +// as an abstracted version of kernel.Task. kernel.Task is not directly used to +// prevent circular dependencies. +type Blocker interface { + Block(C <-chan struct{}) error } // FindOrCreate creates a new message queue or returns an existing one. See @@ -186,6 +193,265 @@ func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error { return nil } +// FindByID returns the queue with the specified ID and an error if the ID +// doesn't exist. +func (r *Registry) FindByID(id ipc.ID) (*Queue, error) { + r.mu.Lock() + defer r.mu.Unlock() + + mech := r.reg.FindByID(id) + if mech == nil { + return nil, linuxerr.EINVAL + } + return mech.(*Queue), nil +} + +// Send appends a message to the message queue, and returns an error if sending +// fails. See msgsnd(2). +func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) { + // Try to perform a non-blocking send using queue.append. If EWOULDBLOCK + // is returned, start the blocking procedure. Otherwise, return normally. + creds := auth.CredentialsFromContext(ctx) + if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + return err + } + + if !wait { + return linuxerr.EAGAIN + } + + e, ch := waiter.NewChannelEntry(nil) + q.senders.EventRegister(&e, waiter.EventOut) + + for { + if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + break + } + b.Block(ch) + } + + q.senders.EventUnregister(&e) + return err +} + +// append appends a message to the queue's message list and notifies waiting +// receivers that a message has been inserted. It returns an error if adding +// the message would cause the queue to exceed its maximum capacity, which can +// be used as a signal to block the task. Other errors should be returned as is. +func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error { + if m.Type <= 0 { + return linuxerr.EINVAL + } + + q.mu.Lock() + defer q.mu.Unlock() + + if !q.obj.CheckPermissions(creds, fs.PermMask{Write: true}) { + // The calling process does not have write permission on the message + // queue, and does not have the CAP_IPC_OWNER capability in the user + // namespace that governs its IPC namespace. + return linuxerr.EACCES + } + + // Queue was removed while the process was waiting. + if q.dead { + return linuxerr.EIDRM + } + + // Check if sufficient space is available (the queue isn't full.) From + // the man pages: + // + // "A message queue is considered to be full if either of the following + // conditions is true: + // + // • Adding a new message to the queue would cause the total number + // of bytes in the queue to exceed the queue's maximum size (the + // msg_qbytes field). + // + // • Adding another message to the queue would cause the total + // number of messages in the queue to exceed the queue's maximum + // size (the msg_qbytes field). This check is necessary to + // prevent an unlimited number of zero-length messages being + // placed on the queue. Although such messages contain no data, + // they nevertheless consume (locked) kernel memory." + // + // The msg_qbytes field in our implementation is q.maxBytes. + if m.Size+q.byteCount > q.maxBytes || q.messageCount+1 > q.maxBytes { + return linuxerr.EWOULDBLOCK + } + + // Copy the message into the queue. + q.messages.PushBack(&m) + + q.byteCount += m.Size + q.messageCount++ + q.sendPID = pid + q.sendTime = ktime.NowFromContext(ctx) + + // Notify receivers about the new message. + q.receivers.Notify(waiter.EventIn) + + return nil +} + +// Receive removes a message from the queue and returns it. See msgrcv(2). +func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) { + if maxSize < 0 || maxSize > maxMessageBytes { + return nil, linuxerr.EINVAL + } + max := uint64(maxSize) + + // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK + // is returned, start the blocking procedure. Otherwise, return normally. + creds := auth.CredentialsFromContext(ctx) + if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { + return msg, err + } + + if !wait { + return nil, linuxerr.ENOMSG + } + + e, ch := waiter.NewChannelEntry(nil) + q.receivers.EventRegister(&e, waiter.EventIn) + + for { + if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { + break + } + b.Block(ch) + } + q.receivers.EventUnregister(&e) + return msg, err +} + +// pop pops the first message from the queue that matches the given type. It +// returns an error for all the cases specified in msgrcv(2). If the queue is +// empty or no message of the specified type is available, a EWOULDBLOCK error +// is returned, which can then be used as a signal to block the process or fail. +func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) { + q.mu.Lock() + defer q.mu.Unlock() + + if !q.obj.CheckPermissions(creds, fs.PermMask{Read: true}) { + // The calling process does not have read permission on the message + // queue, and does not have the CAP_IPC_OWNER capability in the user + // namespace that governs its IPC namespace. + return nil, linuxerr.EACCES + } + + // Queue was removed while the process was waiting. + if q.dead { + return nil, linuxerr.EIDRM + } + + if q.messages.Empty() { + return nil, linuxerr.EWOULDBLOCK + } + + // Get a message from the queue. + switch { + case mType == 0: + msg = q.messages.Front() + case mType > 0: + msg = q.msgOfType(mType, except) + case mType < 0: + msg = q.msgOfTypeLessThan(-1 * mType) + } + + // If no message exists, return a blocking singal. + if msg == nil { + return nil, linuxerr.EWOULDBLOCK + } + + // Check message's size is acceptable. + if maxSize < msg.Size { + if !truncate { + return nil, linuxerr.E2BIG + } + msg.Size = maxSize + msg.Text = msg.Text[:maxSize+1] + } + + q.messages.Remove(msg) + + q.byteCount -= msg.Size + q.messageCount-- + q.receivePID = pid + q.receiveTime = ktime.NowFromContext(ctx) + + // Notify senders about available space. + q.senders.Notify(waiter.EventOut) + + return msg, nil +} + +// Copy copies a message from the queue without deleting it. If no message +// exists, an error is returned. See msgrcv(MSG_COPY). +func (q *Queue) Copy(mType int64) (*Message, error) { + q.mu.Lock() + defer q.mu.Unlock() + + if mType < 0 || q.messages.Empty() { + return nil, linuxerr.ENOMSG + } + + msg := q.msgAtIndex(mType) + if msg == nil { + return nil, linuxerr.ENOMSG + } + return msg, nil +} + +// msgOfType returns the first message with the specified type, nil if no +// message is found. If except is true, the first message of a type not equal +// to mType will be returned. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgOfType(mType int64, except bool) *Message { + if except { + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type != mType { + return msg + } + } + return nil + } + + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type == mType { + return msg + } + } + return nil +} + +// msgOfTypeLessThan return the the first message with the lowest type less +// than or equal to mType, nil if no such message exists. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgOfTypeLessThan(mType int64) (m *Message) { + min := mType + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type <= mType && msg.Type < min { + m = msg + min = msg.Type + } + } + return m +} + +// msgAtIndex returns a pointer to a message at given index, nil if non exits. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgAtIndex(mType int64) *Message { + msg := q.messages.Front() + for ; mType != 0 && msg != nil; mType-- { + msg = msg.Next() + } + return msg +} + // Lock implements ipc.Mechanism.Lock. func (q *Queue) Lock() { q.mu.Lock() diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index 28a613a54..8fd8287b3 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -101,7 +101,7 @@ func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegi // Store the physical address in the slot. This is used to // avoid calls to handleBluepillFault in the future (see // machine.mapPhysical). - atomic.StoreUintptr(&m.usedSlots[slot], physical) + atomic.StoreUintptr(&m.usedSlots[slot], physicalStart) // Successfully added region; we can increment nextSlot and // allow another set to proceed here. atomic.StoreUint32(&m.nextSlot, slot+1) diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index f63ab6aba..0f0c1e73b 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build go1.12 && !go1.18 -// +build go1.12,!go1.18 +//go:build go1.12 +// +build go1.12 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package kvm @@ -28,7 +30,7 @@ import ( ) //go:linkname throw runtime.throw -func throw(string) +func throw(s string) // vCPUPtr returns a CPU for the given address. // diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 1b5d5f66e..e7092a756 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -70,7 +70,7 @@ type machine struct { // tscControl checks whether cpu supports TSC scaling tscControl bool - // usedSlots is the set of used physical addresses (sorted). + // usedSlots is the set of used physical addresses (not sorted). usedSlots []uintptr // nextID is the next vCPU ID. @@ -296,13 +296,20 @@ func newMachine(vm int) (*machine, error) { return m, nil } -// hasSlot returns true iff the given address is mapped. +// hasSlot returns true if the given address is mapped. // // This must be done via a linear scan. // //go:nosplit func (m *machine) hasSlot(physical uintptr) bool { - for i := 0; i < len(m.usedSlots); i++ { + slotLen := int(atomic.LoadUint32(&m.nextSlot)) + // When slots are being updated, nextSlot is ^uint32(0). As this situation + // is less likely happen, we just set the slotLen to m.maxSlots, and scan + // the whole usedSlots array. + if slotLen == int(^uint32(0)) { + slotLen = m.maxSlots + } + for i := 0; i < slotLen; i++ { if p := atomic.LoadUintptr(&m.usedSlots[i]); p == physical { return true } diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 35660e827..cc3a1253b 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build go1.12 && !go1.18 -// +build go1.12,!go1.18 +//go:build go1.12 +// +build go1.12 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package kvm diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index ffd4665f4..304722200 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build go1.12 && !go1.18 -// +build go1.12,!go1.18 +//go:build go1.12 +// +build go1.12 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package ptrace diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 7a4e78a5f..61111ac6c 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -309,6 +309,11 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { return interfaces } +// RemoveInterface implements inet.Stack.RemoveInterface. +func (*Stack) RemoveInterface(int32) error { + return linuxerr.EACCES +} + // InterfaceAddrs implements inet.Stack.InterfaceAddrs. func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { addrs := make(map[int32][]inet.InterfaceAddr) @@ -319,12 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { +func (*Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { return linuxerr.EACCES } // RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. -func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { +func (*Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return linuxerr.EACCES } @@ -339,7 +344,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { } // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. -func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { +func (*Stack) SetTCPReceiveBufferSize(inet.TCPBufferSize) error { return linuxerr.EACCES } @@ -349,7 +354,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { } // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. -func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { +func (*Stack) SetTCPSendBufferSize(inet.TCPBufferSize) error { return linuxerr.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(bool) error { +func (*Stack) SetTCPSACKEnabled(bool) error { return linuxerr.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { +func (*Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return linuxerr.EACCES } @@ -470,19 +475,19 @@ func (s *Stack) RouteTable() []inet.Route { } // Resume implements inet.Stack.Resume. -func (s *Stack) Resume() {} +func (*Stack) Resume() {} // RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. -func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } +func (*Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } // CleanupEndpoints implements inet.Stack.CleanupEndpoints. -func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } +func (*Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. -func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} +func (*Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { +func (*Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return linuxerr.EACCES } @@ -493,6 +498,6 @@ func (*Stack) PortRange() (uint16, uint16) { } // SetPortRange implements inet.Stack.SetPortRange. -func (*Stack) SetPortRange(start uint16, end uint16) error { +func (*Stack) SetPortRange(uint16, uint16) error { return linuxerr.EACCES } diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 86f6419dc..d526acb73 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -161,6 +161,47 @@ func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlin return nil } +// delLink handles RTM_DELLINK requests. +func (p *Protocol) delLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifinfomsg linux.InterfaceInfoMessage + attrs, ok := msg.GetData(&ifinfomsg) + if !ok { + return syserr.ErrInvalidArgument + } + if ifinfomsg.Index == 0 { + // The index is unspecified, search by the interface name. + ahdr, value, _, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + switch ahdr.Type { + case linux.IFLA_IFNAME: + if len(value) < 1 { + return syserr.ErrInvalidArgument + } + ifname := string(value[:len(value)-1]) + for idx, ifa := range stack.Interfaces() { + if ifname == ifa.Name { + ifinfomsg.Index = idx + break + } + } + default: + return syserr.ErrInvalidArgument + } + if ifinfomsg.Index == 0 { + return syserr.ErrNoDevice + } + } + return syserr.FromError(stack.RemoveInterface(ifinfomsg.Index)) +} + // addNewLinkMessage appends RTM_NEWLINK message for the given interface into // the message set. func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { @@ -537,6 +578,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms switch hdr.Type { case linux.RTM_GETLINK: return p.getLink(ctx, msg, ms) + case linux.RTM_DELLINK: + return p.delLink(ctx, msg, ms) case linux.RTM_GETROUTE: return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 0fd0ad32c..208ab9909 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -71,6 +71,12 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { return is } +// RemoveInterface implements inet.Stack.RemoveInterface. +func (s *Stack) RemoveInterface(idx int32) error { + nic := tcpip.NICID(idx) + return syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)).ToError() +} + // InterfaceAddrs implements inet.Stack.InterfaceAddrs. func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { nicAddrs := make(map[int32][]inet.InterfaceAddr) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index ccccce6a9..b5a371d9a 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -86,6 +86,7 @@ go_library( "//pkg/sentry/kernel/eventfd", "//pkg/sentry/kernel/fasync", "//pkg/sentry/kernel/ipc", + "//pkg/sentry/kernel/msgqueue", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/sched", "//pkg/sentry/kernel/shm", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 6f44d767b..1ead3c7e8 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -122,8 +122,8 @@ var AMD64 = &kernel.SyscallTable{ 66: syscalls.Supported("semctl", Semctl), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.Supported("msgget", Msgget), - 69: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 70: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 69: syscalls.Supported("msgsnd", Msgsnd), + 70: syscalls.Supported("msgrcv", Msgrcv), 71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), @@ -618,8 +618,8 @@ var ARM64 = &kernel.SyscallTable{ 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) 186: syscalls.Supported("msgget", Msgget), 187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), - 188: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 189: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 188: syscalls.Supported("msgrcv", Msgrcv), + 189: syscalls.Supported("msgsnd", Msgsnd), 190: syscalls.Supported("semget", Semget), 191: syscalls.Supported("semctl", Semctl), 192: syscalls.Supported("semtimedop", Semtimedop), diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go index 3476e218d..5259ade90 100644 --- a/pkg/sentry/syscalls/linux/sys_msgqueue.go +++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go @@ -17,10 +17,12 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" + "gvisor.dev/gvisor/pkg/sentry/kernel/msgqueue" ) // Msgget implements msgget(2). @@ -41,6 +43,89 @@ func Msgget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return uintptr(queue.ID()), nil, nil } +// Msgsnd implements msgsnd(2). +func Msgsnd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + id := ipc.ID(args[0].Int()) + msgAddr := args[1].Pointer() + size := args[2].Int64() + flag := args[3].Int() + + if size < 0 || size > linux.MSGMAX { + return 0, nil, linuxerr.EINVAL + } + + wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT + pid := int32(t.ThreadGroup().ID()) + + buf := linux.MsgBuf{ + Text: make([]byte, size), + } + if _, err := buf.CopyIn(t, msgAddr); err != nil { + return 0, nil, err + } + + queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id) + if err != nil { + return 0, nil, err + } + + msg := msgqueue.Message{ + Type: int64(buf.Type), + Text: buf.Text, + Size: uint64(size), + } + return 0, nil, queue.Send(t, msg, t, wait, pid) +} + +// Msgrcv implements msgrcv(2). +func Msgrcv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + id := ipc.ID(args[0].Int()) + msgAddr := args[1].Pointer() + size := args[2].Int64() + mType := args[3].Int64() + flag := args[4].Int() + + wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT + except := flag&linux.MSG_EXCEPT == linux.MSG_EXCEPT + truncate := flag&linux.MSG_NOERROR == linux.MSG_NOERROR + + msgCopy := flag&linux.MSG_COPY == linux.MSG_COPY + + msg, err := receive(t, id, mType, size, msgCopy, wait, truncate, except) + if err != nil { + return 0, nil, err + } + + buf := linux.MsgBuf{ + Type: primitive.Int64(msg.Type), + Text: msg.Text, + } + if _, err := buf.CopyOut(t, msgAddr); err != nil { + return 0, nil, err + } + return uintptr(msg.Size), nil, nil +} + +// receive returns a message from the queue with the given ID. If msgCopy is +// true, a message is copied from the queue without being removed. Otherwise, +// a message is removed from the queue and returned. +func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wait, truncate, except bool) (*msgqueue.Message, error) { + pid := int32(t.ThreadGroup().ID()) + + queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id) + if err != nil { + return nil, err + } + + if msgCopy { + if wait || except { + return nil, linuxerr.EINVAL + } + return queue.Copy(mType) + } + return queue.Receive(t, t, mType, maxSize, wait, truncate, except, pid) +} + // Msgctl implements msgctl(2). func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { id := ipc.ID(args[0].Int()) diff --git a/pkg/sync/goyield_unsafe.go b/pkg/sync/goyield_unsafe.go index 8639bb64e..757edbaba 100644 --- a/pkg/sync/goyield_unsafe.go +++ b/pkg/sync/goyield_unsafe.go @@ -3,10 +3,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build go1.14 && !go1.18 -// +build go1.14,!go1.18 +//go:build go1.14 +// +build go1.14 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package sync diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go index 1d9cf304e..49d4109a9 100644 --- a/pkg/sync/runtime_unsafe.go +++ b/pkg/sync/runtime_unsafe.go @@ -6,8 +6,13 @@ //go:build go1.13 && !go1.18 // +build go1.13,!go1.18 -// Check go:linkname function signatures, type definitions, and constants when -// updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. + +// Check type definitions and constants when updating Go version. +// +// TODO(b/165820485): add these checks to checklinkname. package sync @@ -109,10 +114,10 @@ type maptype struct { // These functions are only used within the sync package. //go:linkname semacquire sync.runtime_Semacquire -func semacquire(s *uint32) +func semacquire(addr *uint32) //go:linkname semrelease sync.runtime_Semrelease -func semrelease(s *uint32, handoff bool, skipframes int) +func semrelease(addr *uint32, handoff bool, skipframes int) //go:linkname canSpin sync.runtime_canSpin func canSpin(i int) bool diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index e8e716db0..48356c343 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -56,6 +56,7 @@ import ( // linkDispatcher reads packets from the link FD and dispatches them to the // NetworkDispatcher. type linkDispatcher interface { + stop() dispatch() (bool, tcpip.Error) } @@ -381,16 +382,27 @@ func isSocketFD(fd int) (bool, error) { // Attach launches the goroutine that reads packets from the file descriptor and // dispatches them via the provided dispatcher. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - // Link endpoints are not savable. When transportation endpoints are - // saved, they stop sending outgoing packets and all incoming packets - // are rejected. - for i := range e.inboundDispatchers { - e.wg.Add(1) - go func(i int) { // S/R-SAFE: See above. - e.dispatchLoop(e.inboundDispatchers[i]) - e.wg.Done() - }(i) + // nil means the NIC is being removed. + if dispatcher == nil && e.dispatcher != nil { + for _, dispatcher := range e.inboundDispatchers { + dispatcher.stop() + } + e.Wait() + e.dispatcher = nil + return + } + if dispatcher != nil && e.dispatcher == nil { + e.dispatcher = dispatcher + // Link endpoints are not savable. When transportation endpoints are + // saved, they stop sending outgoing packets and all incoming packets + // are rejected. + for i := range e.inboundDispatchers { + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) + } } } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index bfae34ab9..3f516cab5 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -114,6 +114,7 @@ func (t tPacketHdr) Payload() []byte { // packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. // See: mmap_amd64_unsafe.go for implementation details. type packetMMapDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. fd int @@ -129,18 +130,18 @@ type packetMMapDispatcher struct { ringOffset int } -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, bool, tcpip.Error) { hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) for hdr.tpStatus()&tpStatusUser == 0 { - event := rawfile.PollEvent{ - FD: int32(d.fd), - Events: unix.POLLIN | unix.POLLERR, - } - if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 { + stopped, errno := rawfile.BlockingPollUntilStopped(d.efd, d.fd, unix.POLLIN|unix.POLLERR) + if errno != 0 { if errno == unix.EINTR { continue } - return nil, rawfile.TranslateErrno(errno) + return nil, stopped, rawfile.TranslateErrno(errno) + } + if stopped { + return nil, true, nil } if hdr.tpStatus()&tpStatusCopy != 0 { // This frame is truncated so skip it after flipping the @@ -158,14 +159,14 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { // Release packet to kernel. hdr.setTPStatus(tpStatusKernel) d.ringOffset = (d.ringOffset + 1) % tpFrameNR - return pkt, nil + return pkt, false, nil } // dispatch reads packets from an mmaped ring buffer and dispatches them to the // network stack. func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { - pkt, err := d.readMMappedPacket() - if err != nil { + pkt, stopped, err := d.readMMappedPacket() + if err != nil || stopped { return false, err } var ( diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go index 58d5dfeef..5b786169a 100644 --- a/pkg/tcpip/link/fdbased/mmap_unsafe.go +++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go @@ -47,9 +47,14 @@ func (t tPacketHdr) setTPStatus(status uint32) { } func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFd, err := newStopFd() + if err != nil { + return nil, err + } d := &packetMMapDispatcher{ - fd: fd, - e: e, + stopFd: stopFd, + fd: fd, + e: e, } pageSize := unix.Getpagesize() if tpBlockSize%pageSize != 0 { diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index ab2855a63..fab34c5fa 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -18,6 +18,8 @@ package fdbased import ( + "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -114,9 +116,36 @@ func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView { return buffer.NewVectorisedView(n, views) } +// stopFd is an eventfd used to signal the stop of a dispatcher. +type stopFd struct { + efd int +} + +func newStopFd() (stopFd, error) { + efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK) + if err != nil { + return stopFd{efd: -1}, fmt.Errorf("failed to create eventfd: %w", err) + } + return stopFd{efd: efd}, nil +} + +// stop writes to the eventfd and notifies the dispatcher to stop. It does not +// block. +func (s *stopFd) stop() { + increment := []byte{1, 0, 0, 0, 0, 0, 0, 0} + if n, err := unix.Write(s.efd, increment); n != len(increment) || err != nil { + // There are two possible errors documented in eventfd(2) for writing: + // 1. We are writing 8 bytes and not 0xffffffffffffff, thus no EINVAL. + // 2. stop is only supposed to be called once, it can't reach the limit, + // thus no EAGAIN. + panic(fmt.Sprintf("write(efd) = (%d, %s), want (%d, nil)", n, err, len(increment))) + } +} + // readVDispatcher uses readv() system call to read inbound packets and // dispatches them. type readVDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. fd int @@ -128,7 +157,15 @@ type readVDispatcher struct { } func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{fd: fd, e: e} + stopFd, err := newStopFd() + if err != nil { + return nil, err + } + d := &readVDispatcher{ + stopFd: stopFd, + fd: fd, + e: e, + } skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) return d, nil @@ -136,8 +173,8 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { // dispatch reads one packet from the file descriptor and dispatches it. func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { - n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) - if n == 0 || err != nil { + n, err := rawfile.BlockingReadvUntilStopped(d.efd, d.fd, d.buf.nextIovecs()) + if n <= 0 || err != nil { return false, err } @@ -184,6 +221,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { // recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and // dispatches them. type recvMMsgDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. fd int @@ -207,7 +245,12 @@ const ( ) func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFd, err := newStopFd() + if err != nil { + return nil, err + } d := &recvMMsgDispatcher{ + stopFd: stopFd, fd: fd, e: e, bufs: make([]*iovecBuffer, MaxMsgsPerRecv), @@ -235,8 +278,8 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { d.msgHdrs[k].Msg.SetIovlen(iovLen) } - nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs) - if err != nil { + nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.efd, d.fd, d.msgHdrs) + if nMsgs == -1 || err != nil { return false, err } // Process each of received packets. diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b1a28491d..40bd5560b 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -115,6 +115,13 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + // nil means the NIC is being removed. + if dispatcher == nil { + e.lower.Attach(nil) + e.Wait() + e.dispatcher = nil + return + } e.dispatcher = dispatcher e.lower.Attach(e) } diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index da900c24b..0b7b9e3de 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build ((linux && amd64) || (linux && arm64)) && go1.12 && !go1.18 +//go:build ((linux && amd64) || (linux && arm64)) && go1.12 // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.18 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package rawfile diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 53448a641..e76fc55b6 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -170,46 +170,63 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) { } } -// BlockingReadv reads from a file descriptor that is set up as non-blocking and -// stores the data in a list of iovecs buffers. If no data is available, it will -// block in a poll() syscall until the file descriptor becomes readable. -func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) { +// BlockingReadvUntilStopped reads from a file descriptor that is set up as +// non-blocking and stores the data in a list of iovecs buffers. If no data is +// available, it will block in a poll() syscall until the file descriptor +// becomes readable or stop is signalled (efd becomes readable). Returns -1 in +// the latter case. +func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip.Error) { for { n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) if e == 0 { return int(n), nil } - event := PollEvent{ - FD: int32(fd), - Events: 1, // POLLIN + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, nil } - - _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != unix.EINTR { return 0, TranslateErrno(e) } } } -// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking -// and stores the received messages in a slice of MMsgHdr structures. If no data -// is available, it will block in a poll() syscall until the file descriptor -// becomes readable. -func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { +// BlockingRecvMMsgUntilStopped reads from a file descriptor that is set up as +// non-blocking and stores the received messages in a slice of MMsgHdr +// structures. If no data is available, it will block in a poll() syscall until +// the file descriptor becomes readable or stop is signalled (efd becomes +// readable). Returns -1 in the latter case. +func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { for { n, _, e := unix.RawSyscall6(unix.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0) if e == 0 { return int(n), nil } - event := PollEvent{ - FD: int32(fd), - Events: 1, // POLLIN + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, nil } - - if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != unix.EINTR { + if e != 0 && e != unix.EINTR { return 0, TranslateErrno(e) } } } + +// BlockingPollUntilStopped polls for events on fd or until a stop is signalled +// on the event fd efd. Returns true if stopped, i.e., efd has event POLLIN. +func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) { + pevents := [...]PollEvent{ + { + FD: int32(efd), + Events: unix.POLLIN, + }, + { + FD: int32(fd), + Events: events, + }, + } + _, errno := BlockingPoll(&pevents[0], len(pevents), nil) + return pevents[0].Revents&unix.POLLIN != 0, errno +} diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go index 3bb864ed2..d3edede63 100644 --- a/pkg/tcpip/link/sniffer/pcap.go +++ b/pkg/tcpip/link/sniffer/pcap.go @@ -14,7 +14,14 @@ package sniffer -import "time" +import ( + "encoding" + "encoding/binary" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) type pcapHeader struct { // MagicNumber is the file magic number. @@ -39,25 +46,38 @@ type pcapHeader struct { Network uint32 } -type pcapPacketHeader struct { - // Seconds is the timestamp seconds. - Seconds uint32 - - // Microseconds is the timestamp microseconds. - Microseconds uint32 +var _ encoding.BinaryMarshaler = (*pcapPacket)(nil) - // IncludedLength is the number of octets of packet saved in file. - IncludedLength uint32 - - // OriginalLength is the actual length of packet. - OriginalLength uint32 +type pcapPacket struct { + timestamp time.Time + packet *stack.PacketBuffer + maxCaptureLen int } -func newPCAPPacketHeader(now time.Time, incLen, orgLen uint32) pcapPacketHeader { - return pcapPacketHeader{ - Seconds: uint32(now.Unix()), - Microseconds: uint32(now.Nanosecond() / 1000), - IncludedLength: incLen, - OriginalLength: orgLen, +func (p *pcapPacket) MarshalBinary() ([]byte, error) { + packetSize := p.packet.Size() + captureLen := p.maxCaptureLen + if packetSize < captureLen { + captureLen = packetSize + } + b := make([]byte, 16+captureLen) + binary.BigEndian.PutUint32(b[0:4], uint32(p.timestamp.Unix())) + binary.BigEndian.PutUint32(b[4:8], uint32(p.timestamp.Nanosecond()/1000)) + binary.BigEndian.PutUint32(b[8:12], uint32(captureLen)) + binary.BigEndian.PutUint32(b[12:16], uint32(packetSize)) + w := tcpip.SliceWriter(b[16:]) + for _, v := range p.packet.Views() { + if captureLen == 0 { + break + } + if len(v) > captureLen { + v = v[:captureLen] + } + n, err := w.Write(v) + if err != nil { + panic(err) + } + captureLen -= n } + return b, nil } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 3df826f3c..28a172e71 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -151,33 +151,16 @@ func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumbe logPacket(e.logPrefix, dir, protocol, pkt) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { - totalLength := pkt.Size() - length := totalLength - if max := int(e.maxPCAPLen); length > max { - length = max + packet := pcapPacket{ + timestamp: time.Now(), + packet: pkt, + maxCaptureLen: int(e.maxPCAPLen), } - packetHeader := newPCAPPacketHeader(time.Now(), uint32(length), uint32(totalLength)) - packet := make([]byte, binary.Size(packetHeader)+length) - { - writer := tcpip.SliceWriter(packet) - if err := binary.Write(&writer, binary.BigEndian, packetHeader); err != nil { - panic(err) - } - for _, b := range pkt.Views() { - if length == 0 { - break - } - if len(b) > length { - b = b[:length] - } - n, err := writer.Write(b) - if err != nil { - panic(err) - } - length -= n - } + b, err := packet.MarshalBinary() + if err != nil { + panic(err) } - if _, err := writer.Write(packet); err != nil { + if _, err := writer.Write(b); err != nil { panic(err) } } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index c90974693..2257f728e 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -39,7 +39,6 @@ go_test( "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/internal/testutil", diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 4a4448cf9..73407be67 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -32,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" @@ -3339,7 +3338,7 @@ func TestCloseLocking(t *testing.T) { defer wg.Done() for i := 0; i < iterations; i++ { - if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, defaultMTU, ""))); err != nil { t.Errorf("CreateNIC(%d, _): %s", nicID2, err) return } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 81fabe29a..c73890c4c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -780,6 +780,9 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { if !ok { return &tcpip.ErrUnknownNICID{} } + if nic.IsLoopback() { + return &tcpip.ErrNotSupported{} + } delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 21951d05a..3089c0ef4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -719,38 +719,59 @@ func TestRemoveUnknownNIC(t *testing.T) { } func TestRemoveNIC(t *testing.T) { - const nicID = 1 + for _, tt := range []struct { + name string + linkep stack.LinkEndpoint + expectErr tcpip.Error + }{ + { + name: "loopback", + linkep: loopback.New(), + expectErr: &tcpip.ErrNotSupported{}, + }, + { + name: "channel", + linkep: channel.New(0, defaultMTU, ""), + expectErr: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + e := linkEPWithMockedAttach{ + LinkEndpoint: tt.linkep, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // NIC should be present in NICInfo and attached to a NetworkDispatcher. - allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want { + t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want) + } + if tt.expectErr == nil { + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } + } + }) } } |