diff options
95 files changed, 3903 insertions, 509 deletions
diff --git a/debian/postinst.sh b/debian/postinst.sh index 6a326f823..b387b9f22 100755 --- a/debian/postinst.sh +++ b/debian/postinst.sh @@ -22,7 +22,7 @@ fi if [ -f /etc/docker/daemon.json ]; then runsc install if systemctl is-active -q docker; then - systemctl restart docker || echo "unable to restart docker; you must do so manually." >&2 + systemctl reload docker || echo "unable to reload docker; you must do so manually." >&2 fi fi diff --git a/images/default/Dockerfile b/images/default/Dockerfile index 5f652f2c3..4384d6271 100644 --- a/images/default/Dockerfile +++ b/images/default/Dockerfile @@ -15,7 +15,7 @@ RUN add-apt-repository \ "deb https://download.docker.com/linux/ubuntu \ $(lsb_release -cs) \ stable" -RUN apt-get install docker-ce-cli +RUN apt-get -y install docker-ce-cli # Install gcloud. RUN curl https://dl.google.com/dl/cloudsdk/channels/rapid/downloads/google-cloud-sdk-289.0.0-linux-x86_64.tar.gz | \ @@ -150,6 +150,17 @@ analyzers: external: # Enabled. checkescape: external: # Enabled. + checklinkname: + external: # Enabled. + suppress: + # We don't care to check every single linkname in the Go standard + # library. Suppress findings about stdlib linkname targets we haven't + # described in checklinkname. + # + # Note that we _do_ want to check the signature of the known linkname + # targets in the standard library, so we still need to run + # checklinkname on stdlib generally. 
+ - "linkname to unknown symbol" SA4016: internal: exclude: diff --git a/pkg/bitmap/BUILD b/pkg/bitmap/BUILD new file mode 100644 index 000000000..0f1e7006d --- /dev/null +++ b/pkg/bitmap/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "bitmap", + srcs = ["bitmap.go"], + visibility = ["//:sandbox"], +) + +go_test( + name = "bitmap_test", + size = "small", + srcs = ["bitmap_test.go"], + library = ":bitmap", +) diff --git a/pkg/bitmap/bitmap.go b/pkg/bitmap/bitmap.go new file mode 100644 index 000000000..12d2fc2b8 --- /dev/null +++ b/pkg/bitmap/bitmap.go @@ -0,0 +1,234 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package bitmap provides the implementation of bitmap. +package bitmap + +import ( + "math" + "math/bits" +) + +// Bitmap implements an efficient bitmap. +// +// +stateify savable +type Bitmap struct { + // numOnes is the number of ones in the bitmap. + numOnes uint32 + + // bitBlock holds the bits. The type of bitBlock is uint64 which means + // each number in bitBlock contains 64 entries. + bitBlock []uint64 +} + +// New create a new empty Bitmap. +func New(size uint32) Bitmap { + b := Bitmap{} + bSize := (size + 63) / 64 + b.bitBlock = make([]uint64, bSize) + return b +} + +// IsEmpty verifies whether the Bitmap is empty. 
+func (b *Bitmap) IsEmpty() bool { + return b.numOnes == 0 +} + +// Minimum return the smallest value in the Bitmap. +func (b *Bitmap) Minimum() uint32 { + for i := 0; i < len(b.bitBlock); i++ { + if w := b.bitBlock[i]; w != 0 { + r := bits.TrailingZeros64(w) + return uint32(r + i*64) + } + } + return math.MaxInt32 +} + +// FirstZero returns the first unset bit from the range [start, ). +func (b *Bitmap) FirstZero(start uint32) uint32 { + i, nbit := int(start/64), start%64 + n := len(b.bitBlock) + if i >= n { + return math.MaxInt32 + } + w := b.bitBlock[i] | ((1 << nbit) - 1) + for { + if w != ^uint64(0) { + r := bits.TrailingZeros64(^w) + return uint32(r + i*64) + } + i++ + if i == n { + break + } + w = b.bitBlock[i] + } + return math.MaxInt32 +} + +// Maximum return the largest value in the Bitmap. +func (b *Bitmap) Maximum() uint32 { + for i := len(b.bitBlock) - 1; i >= 0; i-- { + if w := b.bitBlock[i]; w != 0 { + r := bits.LeadingZeros64(w) + return uint32(i*64 + 63 - r) + } + } + return uint32(0) +} + +// Add add i to the Bitmap. +func (b *Bitmap) Add(i uint32) { + blockNum, mask := i/64, uint64(1)<<(i%64) + // if blockNum is out of range, extend b.bitBlock + if x, y := int(blockNum), len(b.bitBlock); x >= y { + b.bitBlock = append(b.bitBlock, make([]uint64, x-y+1)...) + } + oldBlock := b.bitBlock[blockNum] + newBlock := oldBlock | mask + if oldBlock != newBlock { + b.bitBlock[blockNum] = newBlock + b.numOnes++ + } +} + +// Remove i from the Bitmap. +func (b *Bitmap) Remove(i uint32) { + blockNum, mask := i/64, uint64(1)<<(i%64) + oldBlock := b.bitBlock[blockNum] + newBlock := oldBlock &^ mask + if oldBlock != newBlock { + b.bitBlock[blockNum] = newBlock + b.numOnes-- + } +} + +// Clone the Bitmap. +func (b *Bitmap) Clone() Bitmap { + bitmap := Bitmap{b.numOnes, make([]uint64, len(b.bitBlock))} + copy(bitmap.bitBlock, b.bitBlock[:]) + return bitmap +} + +// countOnesForBlocks count all 1 bits within b.bitBlock of begin and that of end. 
+// The begin block and end block are inclusive. +func (b *Bitmap) countOnesForBlocks(begin, end uint32) uint64 { + ones := uint64(0) + beginBlock := begin / 64 + endBlock := end / 64 + for i := beginBlock; i <= endBlock; i++ { + ones += uint64(bits.OnesCount64(b.bitBlock[i])) + } + return ones +} + +// countOnesForAllBlocks count all 1 bits in b.bitBlock. +func (b *Bitmap) countOnesForAllBlocks() uint64 { + ones := uint64(0) + for i := 0; i < len(b.bitBlock); i++ { + ones += uint64(bits.OnesCount64(b.bitBlock[i])) + } + return ones +} + +// flipRange flip the bits within range (begin and end). begin is inclusive and end is exclusive. +func (b *Bitmap) flipRange(begin, end uint32) { + end-- + beginBlock := begin / 64 + endBlock := end / 64 + if beginBlock == endBlock { + b.bitBlock[endBlock] ^= ((^uint64(0) << uint(begin%64)) & ((uint64(1) << (uint(end)%64 + 1)) - 1)) + } else { + b.bitBlock[beginBlock] ^= ^(^uint64(0) << uint(begin%64)) + for i := beginBlock; i < endBlock; i++ { + b.bitBlock[i] = ^b.bitBlock[i] + } + b.bitBlock[endBlock] ^= ((uint64(1) << (uint(end)%64 + 1)) - 1) + } +} + +// clearRange clear the bits within range (begin and end). begin is inclusive and end is exclusive. +func (b *Bitmap) clearRange(begin, end uint32) { + end-- + beginBlock := begin / 64 + endBlock := end / 64 + if beginBlock == endBlock { + b.bitBlock[beginBlock] &= (((uint64(1) << uint(begin%64)) - 1) | ^((uint64(1) << (uint(end)%64 + 1)) - 1)) + } else { + b.bitBlock[beginBlock] &= ((uint64(1) << uint(begin%64)) - 1) + for i := beginBlock + 1; i < endBlock; i++ { + b.bitBlock[i] &= ^b.bitBlock[i] + } + b.bitBlock[endBlock] &= ^((uint64(1) << (uint(end)%64 + 1)) - 1) + } +} + +// ClearRange clear bits within range (begin and end) for the Bitmap. begin is inclusive and end is exclusive. 
+func (b *Bitmap) ClearRange(begin, end uint32) { + blockRange := end/64 - begin/64 + // When the number of cleared blocks is larger than half of the length of b.bitBlock, + // counting 1s for the entire bitmap has better performance. + if blockRange > uint32(len(b.bitBlock)/2) { + b.clearRange(begin, end) + b.numOnes = uint32(b.countOnesForAllBlocks()) + } else { + oldRangeOnes := b.countOnesForBlocks(begin, end) + b.clearRange(begin, end) + newRangeOnes := b.countOnesForBlocks(begin, end) + b.numOnes += uint32(newRangeOnes - oldRangeOnes) + } +} + +// FlipRange flip bits within range (begin and end) for the Bitmap. begin is inclusive and end is exclusive. +func (b *Bitmap) FlipRange(begin, end uint32) { + blockRange := end/64 - begin/64 + // When the number of flipped blocks is larger than half of the length of b.bitBlock, + // counting 1s for the entire bitmap has better performance. + if blockRange > uint32(len(b.bitBlock)/2) { + b.flipRange(begin, end) + b.numOnes = uint32(b.countOnesForAllBlocks()) + } else { + oldRangeOnes := b.countOnesForBlocks(begin, end) + b.flipRange(begin, end) + newRangeOnes := b.countOnesForBlocks(begin, end) + b.numOnes += uint32(newRangeOnes - oldRangeOnes) + } +} + +// ToSlice transform the Bitmap into slice. For example, a bitmap of [0, 1, 0, 1] +// will return the slice [1, 3]. +func (b *Bitmap) ToSlice() []uint32 { + bitmapSlice := make([]uint32, 0, b.numOnes) + // base is the start number of a bitBlock + base := 0 + for i := 0; i < len(b.bitBlock); i++ { + bitBlock := b.bitBlock[i] + // Iterate through all the numbers held by this bit block. + for bitBlock != 0 { + // Extract the lowest set 1 bit. + j := bitBlock & -bitBlock + // Interpret the bit as the in32 number it represents and add it to result. + bitmapSlice = append(bitmapSlice, uint32((base + int(bits.OnesCount64(j-1))))) + bitBlock ^= j + } + base += 64 + } + return bitmapSlice +} + +// GetNumOnes return the the number of ones in the Bitmap. 
+func (b *Bitmap) GetNumOnes() uint32 { + return b.numOnes +} diff --git a/pkg/bitmap/bitmap_test.go b/pkg/bitmap/bitmap_test.go new file mode 100644 index 000000000..76ebd779f --- /dev/null +++ b/pkg/bitmap/bitmap_test.go @@ -0,0 +1,306 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package bitmap + +import ( + "math" + "reflect" + "testing" +) + +// generateFilledSlice generates a slice fill with numbers. +func generateFilledSlice(min, max, length int) []uint32 { + if max == min { + return []uint32{uint32(min)} + } + if length > (max - min) { + return nil + } + randSlice := make([]uint32, length) + if length != 0 { + rangeNum := uint32((max - min) / length) + randSlice[0], randSlice[length-1] = uint32(min), uint32(max) + for i := 1; i < length-1; i++ { + randSlice[i] = randSlice[i-1] + rangeNum + } + } + return randSlice +} + +// generateFilledBitmap generates a Bitmap filled with fillNum of numbers, +// and returns the slice and bitmap. 
+func generateFilledBitmap(min, max, fillNum int) ([]uint32, Bitmap) { + bitmap := New(uint32(max)) + randSlice := generateFilledSlice(min, max, fillNum) + for i := 0; i < fillNum; i++ { + bitmap.Add(randSlice[i]) + } + return randSlice, bitmap +} + +func TestNewBitmap(t *testing.T) { + tests := []struct { + name string + size int + expectSize int + }{ + {"length 1", 1, 1}, + {"length 1024", 1024, 16}, + {"length 1025", 1025, 17}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + if bitmap := New(uint32(tt.size)); len(bitmap.bitBlock) != tt.expectSize { + t.Errorf("New created bitmap with %v, bitBlock size: %d, wanted: %d", tt.name, len(bitmap.bitBlock), tt.expectSize) + } + }) + } +} + +func TestAdd(t *testing.T) { + tests := []struct { + name string + bitmapSize int + addNum int + }{ + {"Add with null bitmap.bitBlock", 0, 10}, + {"Add without extending bitBlock", 64, 10}, + {"Add without extending bitblock with margin number", 63, 64}, + {"Add with extended one block", 1024, 1025}, + {"Add with extended more then one block", 1024, 2048}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + bitmap := New(uint32(tt.bitmapSize)) + bitmap.Add(uint32(tt.addNum)) + bitmapSlice := bitmap.ToSlice() + if bitmapSlice[0] != uint32(tt.addNum) { + t.Errorf("%v, get number: %d, wanted: %d.", tt.name, bitmapSlice[0], tt.addNum) + } + }) + } +} + +func TestRemove(t *testing.T) { + bitmap := New(uint32(1024)) + firstSlice := generateFilledSlice(0, 511, 50) + secondSlice := generateFilledSlice(512, 1024, 50) + for i := 0; i < 50; i++ { + bitmap.Add(firstSlice[i]) + bitmap.Add(secondSlice[i]) + } + + for i := 0; i < 50; i++ { + bitmap.Remove(firstSlice[i]) + } + bitmapSlice := bitmap.ToSlice() + if !reflect.DeepEqual(bitmapSlice, secondSlice) { + t.Errorf("After Remove() firstSlice, remained slice: %v, wanted: %v", bitmapSlice, secondSlice) + } + + for i := 0; i < 50; i++ { + bitmap.Remove(secondSlice[i]) + } + 
bitmapSlice = bitmap.ToSlice() + emptySlice := make([]uint32, 0) + if !reflect.DeepEqual(bitmapSlice, emptySlice) { + t.Errorf("After Remove secondSlice, remained slice: %v, wanted: %v", bitmapSlice, emptySlice) + } + +} + +// Verify flip bits within one bitBlock, one bit and bits cross multi bitBlocks. +func TestFlipRange(t *testing.T) { + tests := []struct { + name string + flipRangeMin int + flipRangeMax int + filledSliceLen int + }{ + {"Flip one number in bitmap", 77, 77, 1}, + {"Flip numbers within one bitBlocks", 5, 60, 20}, + {"Flip numbers that cross multi bitBlocks", 20, 1000, 300}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + fillSlice, bitmap := generateFilledBitmap(tt.flipRangeMin, tt.flipRangeMax, tt.filledSliceLen) + flipFillSlice := make([]uint32, 0) + for i, j := tt.flipRangeMin, 0; i <= tt.flipRangeMax; i++ { + if uint32(i) != fillSlice[j] { + flipFillSlice = append(flipFillSlice, uint32(i)) + } else { + j++ + } + } + + bitmap.FlipRange(uint32(tt.flipRangeMin), uint32(tt.flipRangeMax+1)) + flipBitmapSlice := bitmap.ToSlice() + if !reflect.DeepEqual(flipFillSlice, flipBitmapSlice) { + t.Errorf("%v, flipped slice: %v, wanted: %v", tt.name, flipBitmapSlice, flipFillSlice) + } + }) + } +} + +// Verify clear bits within one bitBlock, one bit and bits cross multi bitBlocks. 
+func TestClearRange(t *testing.T) { + tests := []struct { + name string + clearRangeMin int + clearRangeMax int + bitmapSize int + }{ + {"ClearRange clear one number", 5, 5, 64}, + {"ClearRange clear numbers within one bitBlock", 4, 61, 64}, + {"ClearRange clear numbers cross multi bitBlocks", 20, 254, 512}, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + bitmap := New(uint32(tt.bitmapSize)) + bitmap.FlipRange(uint32(0), uint32(tt.bitmapSize)) + bitmap.ClearRange(uint32(tt.clearRangeMin), uint32(tt.clearRangeMax+1)) + clearedBitmapSlice := bitmap.ToSlice() + clearedSlice := make([]uint32, 0) + for i := 0; i < tt.bitmapSize; i++ { + if i < tt.clearRangeMin || i > tt.clearRangeMax { + clearedSlice = append(clearedSlice, uint32(i)) + } + } + if !reflect.DeepEqual(clearedSlice, clearedBitmapSlice) { + t.Errorf("%v, cleared slice: %v, wanted: %v", tt.name, clearedBitmapSlice, clearedSlice) + } + }) + + } +} + +func TestMinimum(t *testing.T) { + randSlice, bitmap := generateFilledBitmap(0, 1024, 200) + min := bitmap.Minimum() + if min != randSlice[0] { + t.Errorf("Minimum() returns: %v, wanted: %v", min, randSlice[0]) + } + + bitmap.ClearRange(uint32(0), uint32(200)) + min = bitmap.Minimum() + bitmapSlice := bitmap.ToSlice() + if min != bitmapSlice[0] { + t.Errorf("After ClearRange, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0]) + } + + bitmap.FlipRange(uint32(2), uint32(11)) + min = bitmap.Minimum() + bitmapSlice = bitmap.ToSlice() + if min != bitmapSlice[0] { + t.Errorf("After Flip, Minimum() returns: %v, wanted: %v", min, bitmapSlice[0]) + } +} + +func TestMaximum(t *testing.T) { + randSlice, bitmap := generateFilledBitmap(0, 1024, 200) + max := bitmap.Maximum() + if max != randSlice[len(randSlice)-1] { + t.Errorf("Maximum() returns: %v, wanted: %v", max, randSlice[len(randSlice)-1]) + } + + bitmap.ClearRange(uint32(1000), uint32(1025)) + max = bitmap.Maximum() + bitmapSlice := bitmap.ToSlice() + if max != 
bitmapSlice[len(bitmapSlice)-1] { + t.Errorf("After ClearRange, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1]) + } + + bitmap.FlipRange(uint32(1001), uint32(1021)) + max = bitmap.Maximum() + bitmapSlice = bitmap.ToSlice() + if max != bitmapSlice[len(bitmapSlice)-1] { + t.Errorf("After Flip, Maximum() returns: %v, wanted: %v", max, bitmapSlice[len(bitmapSlice)-1]) + } +} + +func TestBitmapNumOnes(t *testing.T) { + randSlice, bitmap := generateFilledBitmap(0, 1024, 200) + bitmapOnes := bitmap.GetNumOnes() + if bitmapOnes != uint32(200) { + t.Errorf("GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) + } + // Remove 10 numbers. + for i := 5; i < 15; i++ { + bitmap.Remove(randSlice[i]) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(190) { + t.Errorf("After Remove 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190) + } + // Remove the 10 number again, the length supposed not change. + for i := 5; i < 15; i++ { + bitmap.Remove(randSlice[i]) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(190) { + t.Errorf("After Remove the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 190) + } + + // Add 10 number. + for i := 1080; i < 1090; i++ { + bitmap.Add(uint32(i)) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(200) { + t.Errorf("After Add 10 number, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) + } + + // Add the 10 number again, the length supposed not change. + for i := 1080; i < 1090; i++ { + bitmap.Add(uint32(i)) + } + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(200) { + t.Errorf("After Add the 10 number again, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 200) + } + + // Flip 10 bits from 0 to 1. 
+ bitmap.FlipRange(uint32(1025), uint32(1035)) + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(210) { + t.Errorf("After Flip, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 210) + } + + // ClearRange numbers range from [0, 1025). + bitmap.ClearRange(uint32(0), uint32(1025)) + bitmapOnes = bitmap.GetNumOnes() + if bitmapOnes != uint32(20) { + t.Errorf("After ClearRange, GetNumOnes() returns: %v, wanted: %v", bitmapOnes, 20) + } +} + +func TestFirstZero(t *testing.T) { + bitmap := New(uint32(1000)) + bitmap.FlipRange(200, 400) + for i, j := range map[uint32]uint32{0: 0, 201: 400, 200: 400, 199: 199, 400: 400, 10000: math.MaxInt32} { + v := bitmap.FirstZero(i) + if v != j { + t.Errorf("Minimum() returns: %v, wanted: %v", v, j) + } + } +} diff --git a/pkg/gohacks/gohacks_unsafe.go b/pkg/gohacks/gohacks_unsafe.go index 09fc14787..bd8ceba19 100644 --- a/pkg/gohacks/gohacks_unsafe.go +++ b/pkg/gohacks/gohacks_unsafe.go @@ -15,7 +15,13 @@ //go:build go1.13 && !go1.18 // +build go1.13,!go1.18 -// Check type signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. + +// Check type signatures and Noescape when updating Go version. +// +// TODO(b/165820485): add these checks to checklinkname. // Package gohacks contains utilities for subverting the Go compiler. package gohacks diff --git a/pkg/procid/procid_amd64.s b/pkg/procid/procid_amd64.s index b5bbfff90..74a8de42c 100644 --- a/pkg/procid/procid_amd64.s +++ b/pkg/procid/procid_amd64.s @@ -15,6 +15,10 @@ //go:build amd64 && go1.8 && !go1.18 && go1.1 // +build amd64,go1.8,!go1.18,go1.1 +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. 
+ #include "textflag.h" TEXT ·Current(SB),NOSPLIT,$0-8 diff --git a/pkg/procid/procid_arm64.s b/pkg/procid/procid_arm64.s index 772d96289..48182c4a9 100644 --- a/pkg/procid/procid_arm64.s +++ b/pkg/procid/procid_arm64.s @@ -15,6 +15,10 @@ //go:build arm64 && go1.8 && !go1.18 && go1.1 // +build arm64,go1.8,!go1.18,go1.1 +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. + #include "textflag.h" TEXT ·Current(SB),NOSPLIT,$0-8 diff --git a/pkg/sentry/control/BUILD b/pkg/sentry/control/BUILD index deaf5fa23..7ee237c9f 100644 --- a/pkg/sentry/control/BUILD +++ b/pkg/sentry/control/BUILD @@ -6,6 +6,8 @@ go_library( name = "control", srcs = [ "control.go", + "fs.go", + "lifecycle.go", "logging.go", "pprof.go", "proc.go", @@ -16,6 +18,7 @@ go_library( ], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/fd", "//pkg/log", "//pkg/sentry/fdimport", @@ -35,6 +38,7 @@ go_library( "//pkg/sync", "//pkg/tcpip/link/sniffer", "//pkg/urpc", + "//pkg/usermem", ], ) diff --git a/pkg/sentry/control/fs.go b/pkg/sentry/control/fs.go new file mode 100644 index 000000000..d19b21f2d --- /dev/null +++ b/pkg/sentry/control/fs.go @@ -0,0 +1,93 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package control + +import ( + "fmt" + "io" + "os" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/fs" + "gvisor.dev/gvisor/pkg/sentry/kernel" + "gvisor.dev/gvisor/pkg/urpc" + "gvisor.dev/gvisor/pkg/usermem" +) + +// CatOpts contains options for the Cat RPC call. +type CatOpts struct { + // Files are the filesystem paths for the files to cat. + Files []string `json:"files"` + + // FilePayload contains the destination for output. + urpc.FilePayload +} + +// Fs includes fs-related functions. +type Fs struct { + Kernel *kernel.Kernel +} + +// Cat is a RPC stub which prints out and returns the content of the files. +func (f *Fs) Cat(o *CatOpts, _ *struct{}) error { + // Create an output stream. + if len(o.FilePayload.Files) != 1 { + return ErrInvalidFiles + } + + output := o.FilePayload.Files[0] + for _, file := range o.Files { + if err := cat(f.Kernel, file, output); err != nil { + return fmt.Errorf("cannot read from file %s: %v", file, err) + } + } + + return nil +} + +// fileReader encapsulates a fs.File and provides an io.Reader interface. +type fileReader struct { + ctx context.Context + file *fs.File +} + +// Read implements io.Reader.Read. 
+func (f *fileReader) Read(p []byte) (int, error) { + n, err := f.file.Readv(f.ctx, usermem.BytesIOSequence(p)) + return int(n), err +} + +func cat(k *kernel.Kernel, path string, output *os.File) error { + ctx := k.SupervisorContext() + mns := k.GlobalInit().Leader().MountNamespace() + root := mns.Root() + defer root.DecRef(ctx) + + remainingTraversals := uint(fs.DefaultTraversalLimit) + d, err := mns.FindInode(ctx, root, nil, path, &remainingTraversals) + if err != nil { + return fmt.Errorf("cannot find file %s: %v", path, err) + } + defer d.DecRef(ctx) + + file, err := d.Inode.GetFile(ctx, d, fs.FileFlags{Read: true}) + if err != nil { + return fmt.Errorf("cannot get file for path %s: %v", path, err) + } + defer file.DecRef(ctx) + + _, err = io.Copy(output, &fileReader{ctx: ctx, file: file}) + return err +} diff --git a/pkg/sentry/control/lifecycle.go b/pkg/sentry/control/lifecycle.go new file mode 100644 index 000000000..67abf497d --- /dev/null +++ b/pkg/sentry/control/lifecycle.go @@ -0,0 +1,36 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package control + +import ( + "gvisor.dev/gvisor/pkg/sentry/kernel" +) + +// Lifecycle provides functions related to starting and stopping tasks. +type Lifecycle struct { + Kernel *kernel.Kernel +} + +// Pause pauses all tasks, blocking until they are stopped. 
+func (l *Lifecycle) Pause(_, _ *struct{}) error { + l.Kernel.Pause() + return nil +} + +// Resume resumes all tasks. +func (l *Lifecycle) Resume(_, _ *struct{}) error { + l.Kernel.Unpause() + return nil +} diff --git a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go index 24e28a51f..22c8b7fda 100644 --- a/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go +++ b/pkg/sentry/fsimpl/cgroupfs/cgroupfs.go @@ -383,11 +383,6 @@ func (d *dir) DecRef(ctx context.Context) { d.dirRefs.DecRef(func() { d.Destroy(ctx) }) } -// StatFS implements kernfs.Inode.StatFS. -func (d *dir) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, error) { - return vfs.GenericStatFS(linux.CGROUP_SUPER_MAGIC), nil -} - // controllerFile represents a generic control file that appears within a cgroup // directory. // diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index ec8d58cc9..25d2e39d6 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -1161,6 +1161,13 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs if !d.isSynthetic() { if stat.Mask != 0 { + if stat.Mask&linux.STATX_SIZE != 0 { + // d.dataMu must be held around the update to both the remote + // file's size and d.size to serialize with writeback (which + // might otherwise write data back up to the old d.size after + // the remote file has been truncated). 
+ d.dataMu.Lock() + } if err := d.file.setAttr(ctx, p9.SetAttrMask{ Permissions: stat.Mask&linux.STATX_MODE != 0, UID: stat.Mask&linux.STATX_UID != 0, @@ -1180,13 +1187,16 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs MTimeSeconds: uint64(stat.Mtime.Sec), MTimeNanoSeconds: uint64(stat.Mtime.Nsec), }); err != nil { + if stat.Mask&linux.STATX_SIZE != 0 { + d.dataMu.Unlock() // +checklocksforce: locked conditionally above + } return err } if stat.Mask&linux.STATX_SIZE != 0 { // d.size should be kept up to date, and privatized // copy-on-write mappings of truncated pages need to be // invalidated, even if InteropModeShared is in effect. - d.updateSizeLocked(stat.Size) + d.updateSizeAndUnlockDataMuLocked(stat.Size) // +checklocksforce: locked conditionally above } } if d.fs.opts.interop == InteropModeShared { @@ -1249,6 +1259,14 @@ func (d *dentry) doAllocate(ctx context.Context, offset, length uint64, allocate // Preconditions: d.metadataMu must be locked. func (d *dentry) updateSizeLocked(newSize uint64) { d.dataMu.Lock() + d.updateSizeAndUnlockDataMuLocked(newSize) +} + +// Preconditions: d.metadataMu and d.dataMu must be locked. +// +// Postconditions: d.dataMu is unlocked. +// +checklocksrelease:d.dataMu +func (d *dentry) updateSizeAndUnlockDataMuLocked(newSize uint64) { oldSize := d.size atomic.StoreUint64(&d.size, newSize) // d.dataMu must be unlocked to lock d.mapsMu and invalidate mappings @@ -1257,9 +1275,9 @@ func (d *dentry) updateSizeLocked(newSize uint64) { // contents beyond the new d.size. (We are still holding d.metadataMu, // so we can't race with Write or another truncate.) 
d.dataMu.Unlock() - if d.size < oldSize { + if newSize < oldSize { oldpgend, _ := hostarch.PageRoundUp(oldSize) - newpgend, _ := hostarch.PageRoundUp(d.size) + newpgend, _ := hostarch.PageRoundUp(newSize) if oldpgend != newpgend { d.mapsMu.Lock() d.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ @@ -1275,8 +1293,8 @@ func (d *dentry) updateSizeLocked(newSize uint64) { // truncated pages have been removed from the remote file, they // should be dropped without being written back. d.dataMu.Lock() - d.cache.Truncate(d.size, d.fs.mfp.MemoryFile()) - d.dirty.KeepClean(memmap.MappableRange{d.size, oldpgend}) + d.cache.Truncate(newSize, d.fs.mfp.MemoryFile()) + d.dirty.KeepClean(memmap.MappableRange{newSize, oldpgend}) d.dataMu.Unlock() } } diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index cbbc0935a..f54811edf 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -78,7 +78,7 @@ func (fs *filesystem) newTaskInode(ctx context.Context, task *kernel.Task, pidns "smaps": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &smapsData{task: task}), "stat": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &taskStatData{task: task, pidns: pidns, tgstats: isThreadGroup}), "statm": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statmData{task: task}), - "status": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0444, &statusData{task: task, pidns: pidns}), + "status": fs.newStatusInode(ctx, task, pidns, fs.NextIno(), 0444), "uid_map": fs.newTaskOwnedInode(ctx, task, fs.NextIno(), 0644, &idMapData{task: task, gids: false}), } if isThreadGroup { diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 5bb6bc372..0ce3ed797 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -661,34 +661,119 @@ func (s *statmData) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -// statusData 
implements vfs.DynamicBytesSource for /proc/[pid]/status. +// statusInode implements kernfs.Inode for /proc/[pid]/status. // // +stateify savable -type statusData struct { - kernfs.DynamicBytesFile +type statusInode struct { + kernfs.InodeAttrs + kernfs.InodeNoStatFS + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink task *kernel.Task pidns *kernel.PIDNamespace + locks vfs.FileLocks } -var _ dynamicInode = (*statusData)(nil) +// statusFD implements vfs.FileDescriptionImpl and vfs.DynamicByteSource for +// /proc/[pid]/status. +// +// +stateify savable +type statusFD struct { + statusFDLowerBase + vfs.DynamicBytesFileDescriptionImpl + vfs.LockFD + + vfsfd vfs.FileDescription + + inode *statusInode + task *kernel.Task + pidns *kernel.PIDNamespace + userns *auth.UserNamespace // equivalent to struct file::f_cred::user_ns +} + +// statusFDLowerBase is a dumb hack to ensure that statusFD prefers +// vfs.DynamicBytesFileDescriptionImpl methods to vfs.FileDescriptinDefaultImpl +// methods. +// +// +stateify savable +type statusFDLowerBase struct { + vfs.FileDescriptionDefaultImpl +} + +func (fs *filesystem) newStatusInode(ctx context.Context, task *kernel.Task, pidns *kernel.PIDNamespace, ino uint64, perm linux.FileMode) kernfs.Inode { + // Note: credentials are overridden by taskOwnedInode. + inode := &statusInode{ + task: task, + pidns: pidns, + } + inode.InodeAttrs.Init(ctx, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeRegular|perm) + return &taskOwnedInode{Inode: inode, owner: task} +} + +// Open implements kernfs.Inode.Open. 
+func (s *statusInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + fd := &statusFD{ + inode: s, + task: s.task, + pidns: s.pidns, + userns: rp.Credentials().UserNamespace, + } + fd.LockFD.Init(&s.locks) + if err := fd.vfsfd.Init(fd, opts.Flags, rp.Mount(), d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return nil, err + } + fd.SetDataSource(fd) + return &fd.vfsfd, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (*statusInode) SetStat(ctx context.Context, vfsfs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { + return linuxerr.EPERM +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (s *statusFD) Release(ctx context.Context) { +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (s *statusFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := s.vfsfd.VirtualDentry().Mount().Filesystem() + return s.inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (s *statusFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { + return linuxerr.EPERM +} // Generate implements vfs.DynamicBytesSource.Generate. 
-func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { +func (s *statusFD) Generate(ctx context.Context, buf *bytes.Buffer) error { fmt.Fprintf(buf, "Name:\t%s\n", s.task.Name()) fmt.Fprintf(buf, "State:\t%s\n", s.task.StateStatus()) fmt.Fprintf(buf, "Tgid:\t%d\n", s.pidns.IDOfThreadGroup(s.task.ThreadGroup())) fmt.Fprintf(buf, "Pid:\t%d\n", s.pidns.IDOfTask(s.task)) + ppid := kernel.ThreadID(0) if parent := s.task.Parent(); parent != nil { ppid = s.pidns.IDOfThreadGroup(parent.ThreadGroup()) } fmt.Fprintf(buf, "PPid:\t%d\n", ppid) + tpid := kernel.ThreadID(0) if tracer := s.task.Tracer(); tracer != nil { tpid = s.pidns.IDOfTask(tracer) } fmt.Fprintf(buf, "TracerPid:\t%d\n", tpid) + + creds := s.task.Credentials() + ruid := creds.RealKUID.In(s.userns).OrOverflow() + euid := creds.EffectiveKUID.In(s.userns).OrOverflow() + suid := creds.SavedKUID.In(s.userns).OrOverflow() + rgid := creds.RealKGID.In(s.userns).OrOverflow() + egid := creds.EffectiveKGID.In(s.userns).OrOverflow() + sgid := creds.SavedKGID.In(s.userns).OrOverflow() var fds int var vss, rss, data uint64 s.task.WithMuLocked(func(t *kernel.Task) { @@ -701,12 +786,26 @@ func (s *statusData) Generate(ctx context.Context, buf *bytes.Buffer) error { data = mm.VirtualDataSize() } }) + // Filesystem user/group IDs aren't implemented; effective UID/GID are used + // instead. + fmt.Fprintf(buf, "Uid:\t%d\t%d\t%d\t%d\n", ruid, euid, suid, euid) + fmt.Fprintf(buf, "Gid:\t%d\t%d\t%d\t%d\n", rgid, egid, sgid, egid) fmt.Fprintf(buf, "FDSize:\t%d\n", fds) + buf.WriteString("Groups:\t ") + // There is a space between each pair of supplemental GIDs, as well as an + // unconditional trailing space that some applications actually depend on. 
+ var sep string + for _, kgid := range creds.ExtraKGIDs { + fmt.Fprintf(buf, "%s%d", sep, kgid.In(s.userns).OrOverflow()) + sep = " " + } + buf.WriteString(" \n") + fmt.Fprintf(buf, "VmSize:\t%d kB\n", vss>>10) fmt.Fprintf(buf, "VmRSS:\t%d kB\n", rss>>10) fmt.Fprintf(buf, "VmData:\t%d kB\n", data>>10) + fmt.Fprintf(buf, "Threads:\t%d\n", s.task.ThreadGroup().Count()) - creds := s.task.Credentials() fmt.Fprintf(buf, "CapInh:\t%016x\n", creds.InheritableCaps) fmt.Fprintf(buf, "CapPrm:\t%016x\n", creds.PermittedCaps) fmt.Fprintf(buf, "CapEff:\t%016x\n", creds.EffectiveCaps) diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index 80dda1559..b121fc1b4 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -27,6 +27,9 @@ type Stack interface { // integers. Interfaces() map[int32]Interface + // RemoveInterface removes the specified network interface. + RemoveInterface(idx int32) error + // InterfaceAddrs returns all network interface addresses as a mapping from // interface indexes to a slice of associated interface address properties. InterfaceAddrs() map[int32][]InterfaceAddr diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 218d9dafc..621f47e1f 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -45,23 +45,29 @@ func NewTestStack() *TestStack { } } -// Interfaces implements Stack.Interfaces. +// Interfaces implements Stack. func (s *TestStack) Interfaces() map[int32]Interface { return s.InterfacesMap } -// InterfaceAddrs implements Stack.InterfaceAddrs. +// RemoveInterface implements Stack. +func (s *TestStack) RemoveInterface(idx int32) error { + delete(s.InterfacesMap, idx) + return nil +} + +// InterfaceAddrs implements Stack. func (s *TestStack) InterfaceAddrs() map[int32][]InterfaceAddr { return s.InterfaceAddrsMap } -// AddInterfaceAddr implements Stack.AddInterfaceAddr. +// AddInterfaceAddr implements Stack. 
func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { s.InterfaceAddrsMap[idx] = append(s.InterfaceAddrsMap[idx], addr) return nil } -// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +// RemoveInterfaceAddr implements Stack. func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { interfaceAddrs, ok := s.InterfaceAddrsMap[idx] if !ok { @@ -79,94 +85,94 @@ func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } -// SupportsIPv6 implements Stack.SupportsIPv6. +// SupportsIPv6 implements Stack. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag } -// TCPReceiveBufferSize implements Stack.TCPReceiveBufferSize. +// TCPReceiveBufferSize implements Stack. func (s *TestStack) TCPReceiveBufferSize() (TCPBufferSize, error) { return s.TCPRecvBufSize, nil } -// SetTCPReceiveBufferSize implements Stack.SetTCPReceiveBufferSize. +// SetTCPReceiveBufferSize implements Stack. func (s *TestStack) SetTCPReceiveBufferSize(size TCPBufferSize) error { s.TCPRecvBufSize = size return nil } -// TCPSendBufferSize implements Stack.TCPSendBufferSize. +// TCPSendBufferSize implements Stack. func (s *TestStack) TCPSendBufferSize() (TCPBufferSize, error) { return s.TCPSendBufSize, nil } -// SetTCPSendBufferSize implements Stack.SetTCPSendBufferSize. +// SetTCPSendBufferSize implements Stack. func (s *TestStack) SetTCPSendBufferSize(size TCPBufferSize) error { s.TCPSendBufSize = size return nil } -// TCPSACKEnabled implements Stack.TCPSACKEnabled. +// TCPSACKEnabled implements Stack. func (s *TestStack) TCPSACKEnabled() (bool, error) { return s.TCPSACKFlag, nil } -// SetTCPSACKEnabled implements Stack.SetTCPSACKEnabled. +// SetTCPSACKEnabled implements Stack. func (s *TestStack) SetTCPSACKEnabled(enabled bool) error { s.TCPSACKFlag = enabled return nil } -// TCPRecovery implements Stack.TCPRecovery. +// TCPRecovery implements Stack. 
func (s *TestStack) TCPRecovery() (TCPLossRecovery, error) { return s.Recovery, nil } -// SetTCPRecovery implements Stack.SetTCPRecovery. +// SetTCPRecovery implements Stack. func (s *TestStack) SetTCPRecovery(recovery TCPLossRecovery) error { s.Recovery = recovery return nil } -// Statistics implements inet.Stack.Statistics. +// Statistics implements Stack. func (s *TestStack) Statistics(stat interface{}, arg string) error { return nil } -// RouteTable implements Stack.RouteTable. +// RouteTable implements Stack. func (s *TestStack) RouteTable() []Route { return s.RouteList } -// Resume implements Stack.Resume. +// Resume implements Stack. func (s *TestStack) Resume() {} -// RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. +// RegisteredEndpoints implements Stack. func (s *TestStack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } -// CleanupEndpoints implements inet.Stack.CleanupEndpoints. +// CleanupEndpoints implements Stack. func (s *TestStack) CleanupEndpoints() []stack.TransportEndpoint { return nil } -// RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. +// RestoreCleanupEndpoints implements Stack. func (s *TestStack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} -// SetForwarding implements inet.Stack.SetForwarding. +// SetForwarding implements Stack. func (s *TestStack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { s.IPForwarding = enable return nil } -// PortRange implements inet.Stack.PortRange. +// PortRange implements Stack. func (*TestStack) PortRange() (uint16, uint16) { // Use the default Linux values per net/ipv4/af_inet.c:inet_init_net(). return 32768, 28232 } -// SetPortRange implements inet.Stack.SetPortRange. +// SetPortRange implements Stack. func (*TestStack) SetPortRange(start uint16, end uint16) error { // No-op. 
return nil diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index c613f4932..e4e0dc04f 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -220,6 +220,7 @@ go_library( "//pkg/abi/linux", "//pkg/abi/linux/errno", "//pkg/amutex", + "//pkg/bitmap", "//pkg/bits", "//pkg/bpf", "//pkg/cleanup", diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 8786a70b5..eff556a0c 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -18,10 +18,10 @@ import ( "fmt" "math" "strings" - "sync/atomic" "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/bitmap" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/errors/linuxerr" "gvisor.dev/gvisor/pkg/sentry/fs" @@ -84,13 +84,8 @@ type FDTable struct { // mu protects below. mu sync.Mutex `state:"nosave"` - // next is start position to find fd. - next int32 - - // used contains the number of non-nil entries. It must be accessed - // atomically. It may be read atomically without holding mu (but not - // written). - used int32 + // fdBitmap shows which fds are already in use. + fdBitmap bitmap.Bitmap `state:"nosave"` // descriptorTable holds descriptors. descriptorTable `state:".(map[int32]descriptor)"` @@ -98,6 +93,8 @@ type FDTable struct { func (f *FDTable) saveDescriptorTable() map[int32]descriptor { m := make(map[int32]descriptor) + f.mu.Lock() + defer f.mu.Unlock() f.forEach(context.Background(), func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { m[fd] = descriptor{ file: file, @@ -111,12 +108,16 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() f.initNoLeakCheck() // Initialize table. - f.used = 0 + f.fdBitmap = bitmap.New(uint32(math.MaxUint16)) for fd, d := range m { + if fd < 0 { + panic(fmt.Sprintf("FD is not supposed to be negative. 
FD: %d", fd)) + } + if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { panic("VFS1 or VFS2 files set") } - + f.fdBitmap.Add(uint32(fd)) // Note that we do _not_ need to acquire a extra table reference here. The // table reference will already be accounted for in the file, so we drop the // reference taken by set above. @@ -189,8 +190,10 @@ func (f *FDTable) DecRef(ctx context.Context) { func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags)) { // retries tracks the number of failed TryIncRef attempts for the same FD. retries := 0 - fd := int32(0) - for { + fds := f.fdBitmap.ToSlice() + // Iterate through the fdBitmap. + for _, ufd := range fds { + fd := int32(ufd) file, fileVFS2, flags, ok := f.getAll(fd) if !ok { break @@ -218,7 +221,6 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, fileVFS2.DecRef(ctx) } retries = 0 - fd++ } } @@ -226,6 +228,8 @@ func (f *FDTable) forEach(ctx context.Context, fn func(fd int32, file *fs.File, func (f *FDTable) String() string { var buf strings.Builder ctx := context.Background() + f.mu.Lock() + defer f.mu.Unlock() f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { switch { case file != nil: @@ -250,10 +254,10 @@ func (f *FDTable) String() string { } // NewFDs allocates new FDs guaranteed to be the lowest number available -// greater than or equal to the fd parameter. All files will share the set +// greater than or equal to the minFD parameter. All files will share the set // flags. Success is guaranteed to be all or none. -func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags FDFlags) (fds []int32, err error) { - if fd < 0 { +func (f *FDTable) NewFDs(ctx context.Context, minFD int32, files []*fs.File, flags FDFlags) (fds []int32, err error) { + if minFD < 0 { // Don't accept negative FDs. 
return nil, unix.EINVAL } @@ -267,31 +271,48 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags if lim.Cur != limits.Infinity { end = int32(lim.Cur) } - if fd >= end { + if minFD+int32(len(files)) > end { return nil, unix.EMFILE } } f.mu.Lock() - // From f.next to find available fd. - if fd < f.next { - fd = f.next + // max is used as the largest number in fdBitmap + 1. + max := int32(0) + + if !f.fdBitmap.IsEmpty() { + max = int32(f.fdBitmap.Maximum()) + max++ } + // Adjust max in case it is less than minFD. + if max < minFD { + max = minFD + } // Install all entries. - for i := fd; i < end && len(fds) < len(files); i++ { - if d, _, _ := f.get(i); d == nil { - // Set the descriptor. - f.set(ctx, i, files[len(fds)], flags) - fds = append(fds, i) // Record the file descriptor. + for len(fds) < len(files) { + // Try to use free bit in fdBitmap. + // If all bits in fdBitmap are used, expand fd to the max. + fd := f.fdBitmap.FirstZero(uint32(minFD)) + if fd == math.MaxInt32 { + fd = uint32(max) + max++ + } + if fd >= uint32(end) { + break } + f.fdBitmap.Add(fd) + f.set(ctx, int32(fd), files[len(fds)], flags) + fds = append(fds, int32(fd)) + minFD = int32(fd) } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { f.set(ctx, i, nil, FDFlags{}) + f.fdBitmap.Remove(uint32(i)) } f.mu.Unlock() @@ -305,20 +326,15 @@ func (f *FDTable) NewFDs(ctx context.Context, fd int32, files []*fs.File, flags return nil, unix.EMFILE } - if fd == f.next { - // Update next search start position. - f.next = fds[len(fds)-1] + 1 - } - f.mu.Unlock() return fds, nil } // NewFDsVFS2 allocates new FDs guaranteed to be the lowest number available -// greater than or equal to the fd parameter. All files will share the set +// greater than or equal to the minFD parameter. All files will share the set // flags. Success is guaranteed to be all or none. 
-func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) { - if fd < 0 { +func (f *FDTable) NewFDsVFS2(ctx context.Context, minFD int32, files []*vfs.FileDescription, flags FDFlags) (fds []int32, err error) { + if minFD < 0 { // Don't accept negative FDs. return nil, unix.EINVAL } @@ -332,31 +348,47 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes if lim.Cur != limits.Infinity { end = int32(lim.Cur) } - if fd >= end { + if minFD >= end { return nil, unix.EMFILE } } f.mu.Lock() - // From f.next to find available fd. - if fd < f.next { - fd = f.next + // max is used as the largest number in fdBitmap + 1. + max := int32(0) + + if !f.fdBitmap.IsEmpty() { + max = int32(f.fdBitmap.Maximum()) + max++ } - // Install all entries. - for i := fd; i < end && len(fds) < len(files); i++ { - if d, _, _ := f.getVFS2(i); d == nil { - // Set the descriptor. - f.setVFS2(ctx, i, files[len(fds)], flags) - fds = append(fds, i) // Record the file descriptor. - } + // Adjust max in case it is less than minFD. + if max < minFD { + max = minFD } + for len(fds) < len(files) { + // Try to use free bit in fdBitmap. + // If all bits in fdBitmap are used, expand fd to the max. + fd := f.fdBitmap.FirstZero(uint32(minFD)) + if fd == math.MaxInt32 { + fd = uint32(max) + max++ + } + if fd >= uint32(end) { + break + } + f.fdBitmap.Add(fd) + f.setVFS2(ctx, int32(fd), files[len(fds)], flags) + fds = append(fds, int32(fd)) + minFD = int32(fd) + } // Failure? Unwind existing FDs. if len(fds) < len(files) { for _, i := range fds { f.setVFS2(ctx, i, nil, FDFlags{}) + f.fdBitmap.Remove(uint32(i)) } f.mu.Unlock() @@ -370,57 +402,19 @@ func (f *FDTable) NewFDsVFS2(ctx context.Context, fd int32, files []*vfs.FileDes return nil, unix.EMFILE } - if fd == f.next { - // Update next search start position. 
- f.next = fds[len(fds)-1] + 1 - } - f.mu.Unlock() return fds, nil } -// NewFDVFS2 allocates a file descriptor greater than or equal to minfd for +// NewFDVFS2 allocates a file descriptor greater than or equal to minFD for // the given file description. If it succeeds, it takes a reference on file. -func (f *FDTable) NewFDVFS2(ctx context.Context, minfd int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { - if minfd < 0 { - // Don't accept negative FDs. - return -1, unix.EINVAL - } - - // Default limit. - end := int32(math.MaxInt32) - - // Ensure we don't get past the provided limit. - if limitSet := limits.FromContext(ctx); limitSet != nil { - lim := limitSet.Get(limits.NumberOfFiles) - if lim.Cur != limits.Infinity { - end = int32(lim.Cur) - } - if minfd >= end { - return -1, unix.EMFILE - } - } - - f.mu.Lock() - defer f.mu.Unlock() - - // From f.next to find available fd. - fd := minfd - if fd < f.next { - fd = f.next - } - for fd < end { - if d, _, _ := f.getVFS2(fd); d == nil { - f.setVFS2(ctx, fd, file, flags) - if fd == f.next { - // Update next search start position. - f.next = fd + 1 - } - return fd, nil - } - fd++ +func (f *FDTable) NewFDVFS2(ctx context.Context, minFD int32, file *vfs.FileDescription, flags FDFlags) (int32, error) { + files := []*vfs.FileDescription{file} + fileSlice, error := f.NewFDsVFS2(ctx, minFD, files, flags) + if error != nil { + return -1, error } - return -1, unix.EMFILE + return fileSlice[0], nil } // NewFDAt sets the file reference for the given FD. If there is an active @@ -469,6 +463,11 @@ func (f *FDTable) newFDAt(ctx context.Context, fd int32, file *fs.File, fileVFS2 defer f.mu.Unlock() df, dfVFS2 := f.setAll(ctx, fd, file, fileVFS2, flags) + // Add fd to fdBitmap. 
+ if file != nil || fileVFS2 != nil { + f.fdBitmap.Add(uint32(fd)) + } + return df, dfVFS2, nil } @@ -573,7 +572,9 @@ func (f *FDTable) GetVFS2(fd int32) (*vfs.FileDescription, FDFlags) { // Precondition: The caller must be running on the task goroutine, or Task.mu // must be locked. func (f *FDTable) GetFDs(ctx context.Context) []int32 { - fds := make([]int32, 0, int(atomic.LoadInt32(&f.used))) + f.mu.Lock() + defer f.mu.Unlock() + fds := make([]int32, 0, int(f.fdBitmap.GetNumOnes())) f.forEach(ctx, func(fd int32, _ *fs.File, _ *vfs.FileDescription, _ FDFlags) { fds = append(fds, fd) }) @@ -583,13 +584,15 @@ func (f *FDTable) GetFDs(ctx context.Context) []int32 { // Fork returns an independent FDTable. func (f *FDTable) Fork(ctx context.Context) *FDTable { clone := f.k.NewFDTable() - + f.mu.Lock() + defer f.mu.Unlock() f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { // The set function here will acquire an appropriate table // reference for the clone. We don't need anything else. if df, dfVFS2 := clone.setAll(ctx, fd, file, fileVFS2, flags); df != nil || dfVFS2 != nil { panic("VFS1 or VFS2 files set") } + clone.fdBitmap.Add(uint32(fd)) }) return clone } @@ -604,11 +607,6 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc f.mu.Lock() - // Update current available position. - if fd < f.next { - f.next = fd - } - orig, orig2, _, _ := f.getAll(fd) // Add reference for caller. @@ -621,6 +619,7 @@ func (f *FDTable) Remove(ctx context.Context, fd int32) (*fs.File, *vfs.FileDesc if orig != nil || orig2 != nil { orig, orig2 = f.setAll(ctx, fd, nil, nil, FDFlags{}) // Zap entry. 
+ f.fdBitmap.Remove(uint32(fd)) } f.mu.Unlock() @@ -644,16 +643,13 @@ func (f *FDTable) RemoveIf(ctx context.Context, cond func(*fs.File, *vfs.FileDes f.forEach(ctx, func(fd int32, file *fs.File, fileVFS2 *vfs.FileDescription, flags FDFlags) { if cond(file, fileVFS2, flags) { df, dfVFS2 := f.setAll(ctx, fd, nil, nil, FDFlags{}) // Clear from table. + f.fdBitmap.Remove(uint32(fd)) if df != nil { files = append(files, df) } if dfVFS2 != nil { filesVFS2 = append(filesVFS2, dfVFS2) } - // Update current available position. - if fd < f.next { - f.next = fd - } } }) f.mu.Unlock() diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index f17f9c59c..2b3e6ef71 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -15,9 +15,11 @@ package kernel import ( + "math" "sync/atomic" "unsafe" + "gvisor.dev/gvisor/pkg/bitmap" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -44,6 +46,7 @@ func (f *FDTable) initNoLeakCheck() { func (f *FDTable) init() { f.initNoLeakCheck() f.InitRefs() + f.fdBitmap = bitmap.New(uint32(math.MaxUint16)) } // get gets a file entry. @@ -162,14 +165,6 @@ func (f *FDTable) setAll(ctx context.Context, fd int32, file *fs.File, fileVFS2 } } - // Adjust used. - switch { - case orig == nil && desc != nil: - atomic.AddInt32(&f.used, 1) - case orig != nil && desc == nil: - atomic.AddInt32(&f.used, -1) - } - if orig != nil { switch { case orig.file != nil: diff --git a/pkg/sentry/kernel/msgqueue/msgqueue.go b/pkg/sentry/kernel/msgqueue/msgqueue.go index 3ce926950..c111297d7 100644 --- a/pkg/sentry/kernel/msgqueue/msgqueue.go +++ b/pkg/sentry/kernel/msgqueue/msgqueue.go @@ -119,14 +119,21 @@ type Queue struct { type Message struct { msgEntry - // mType is an integer representing the type of the sent message. - mType int64 + // Type is an integer representing the type of the sent message. 
+ Type int64 - // mText is an untyped block of memory. - mText []byte + // Text is an untyped block of memory. + Text []byte - // mSize is the size of mText. - mSize uint64 + // Size is the size of Text. + Size uint64 +} + +// Blocker is used for blocking Queue.Send, and Queue.Receive calls that serves +// as an abstracted version of kernel.Task. kernel.Task is not directly used to +// prevent circular dependencies. +type Blocker interface { + Block(C <-chan struct{}) error } // FindOrCreate creates a new message queue or returns an existing one. See @@ -186,6 +193,265 @@ func (r *Registry) Remove(id ipc.ID, creds *auth.Credentials) error { return nil } +// FindByID returns the queue with the specified ID and an error if the ID +// doesn't exist. +func (r *Registry) FindByID(id ipc.ID) (*Queue, error) { + r.mu.Lock() + defer r.mu.Unlock() + + mech := r.reg.FindByID(id) + if mech == nil { + return nil, linuxerr.EINVAL + } + return mech.(*Queue), nil +} + +// Send appends a message to the message queue, and returns an error if sending +// fails. See msgsnd(2). +func (q *Queue) Send(ctx context.Context, m Message, b Blocker, wait bool, pid int32) (err error) { + // Try to perform a non-blocking send using queue.append. If EWOULDBLOCK + // is returned, start the blocking procedure. Otherwise, return normally. + creds := auth.CredentialsFromContext(ctx) + if err := q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + return err + } + + if !wait { + return linuxerr.EAGAIN + } + + e, ch := waiter.NewChannelEntry(nil) + q.senders.EventRegister(&e, waiter.EventOut) + + for { + if err = q.append(ctx, m, creds, pid); err != linuxerr.EWOULDBLOCK { + break + } + b.Block(ch) + } + + q.senders.EventUnregister(&e) + return err +} + +// append appends a message to the queue's message list and notifies waiting +// receivers that a message has been inserted. 
It returns an error if adding +// the message would cause the queue to exceed its maximum capacity, which can +// be used as a signal to block the task. Other errors should be returned as is. +func (q *Queue) append(ctx context.Context, m Message, creds *auth.Credentials, pid int32) error { + if m.Type <= 0 { + return linuxerr.EINVAL + } + + q.mu.Lock() + defer q.mu.Unlock() + + if !q.obj.CheckPermissions(creds, fs.PermMask{Write: true}) { + // The calling process does not have write permission on the message + // queue, and does not have the CAP_IPC_OWNER capability in the user + // namespace that governs its IPC namespace. + return linuxerr.EACCES + } + + // Queue was removed while the process was waiting. + if q.dead { + return linuxerr.EIDRM + } + + // Check if sufficient space is available (the queue isn't full.) From + // the man pages: + // + // "A message queue is considered to be full if either of the following + // conditions is true: + // + // • Adding a new message to the queue would cause the total number + // of bytes in the queue to exceed the queue's maximum size (the + // msg_qbytes field). + // + // • Adding another message to the queue would cause the total + // number of messages in the queue to exceed the queue's maximum + // size (the msg_qbytes field). This check is necessary to + // prevent an unlimited number of zero-length messages being + // placed on the queue. Although such messages contain no data, + // they nevertheless consume (locked) kernel memory." + // + // The msg_qbytes field in our implementation is q.maxBytes. + if m.Size+q.byteCount > q.maxBytes || q.messageCount+1 > q.maxBytes { + return linuxerr.EWOULDBLOCK + } + + // Copy the message into the queue. + q.messages.PushBack(&m) + + q.byteCount += m.Size + q.messageCount++ + q.sendPID = pid + q.sendTime = ktime.NowFromContext(ctx) + + // Notify receivers about the new message. 
+ q.receivers.Notify(waiter.EventIn) + + return nil +} + +// Receive removes a message from the queue and returns it. See msgrcv(2). +func (q *Queue) Receive(ctx context.Context, b Blocker, mType int64, maxSize int64, wait, truncate, except bool, pid int32) (msg *Message, err error) { + if maxSize < 0 || maxSize > maxMessageBytes { + return nil, linuxerr.EINVAL + } + max := uint64(maxSize) + + // Try to perform a non-blocking receive using queue.pop. If EWOULDBLOCK + // is returned, start the blocking procedure. Otherwise, return normally. + creds := auth.CredentialsFromContext(ctx) + if msg, err := q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { + return msg, err + } + + if !wait { + return nil, linuxerr.ENOMSG + } + + e, ch := waiter.NewChannelEntry(nil) + q.receivers.EventRegister(&e, waiter.EventIn) + + for { + if msg, err = q.pop(ctx, creds, mType, max, truncate, except, pid); err != linuxerr.EWOULDBLOCK { + break + } + b.Block(ch) + } + q.receivers.EventUnregister(&e) + return msg, err +} + +// pop pops the first message from the queue that matches the given type. It +// returns an error for all the cases specified in msgrcv(2). If the queue is +// empty or no message of the specified type is available, a EWOULDBLOCK error +// is returned, which can then be used as a signal to block the process or fail. +func (q *Queue) pop(ctx context.Context, creds *auth.Credentials, mType int64, maxSize uint64, truncate, except bool, pid int32) (msg *Message, _ error) { + q.mu.Lock() + defer q.mu.Unlock() + + if !q.obj.CheckPermissions(creds, fs.PermMask{Read: true}) { + // The calling process does not have read permission on the message + // queue, and does not have the CAP_IPC_OWNER capability in the user + // namespace that governs its IPC namespace. + return nil, linuxerr.EACCES + } + + // Queue was removed while the process was waiting. 
+ if q.dead { + return nil, linuxerr.EIDRM + } + + if q.messages.Empty() { + return nil, linuxerr.EWOULDBLOCK + } + + // Get a message from the queue. + switch { + case mType == 0: + msg = q.messages.Front() + case mType > 0: + msg = q.msgOfType(mType, except) + case mType < 0: + msg = q.msgOfTypeLessThan(-1 * mType) + } + + // If no message exists, return a blocking singal. + if msg == nil { + return nil, linuxerr.EWOULDBLOCK + } + + // Check message's size is acceptable. + if maxSize < msg.Size { + if !truncate { + return nil, linuxerr.E2BIG + } + msg.Size = maxSize + msg.Text = msg.Text[:maxSize+1] + } + + q.messages.Remove(msg) + + q.byteCount -= msg.Size + q.messageCount-- + q.receivePID = pid + q.receiveTime = ktime.NowFromContext(ctx) + + // Notify senders about available space. + q.senders.Notify(waiter.EventOut) + + return msg, nil +} + +// Copy copies a message from the queue without deleting it. If no message +// exists, an error is returned. See msgrcv(MSG_COPY). +func (q *Queue) Copy(mType int64) (*Message, error) { + q.mu.Lock() + defer q.mu.Unlock() + + if mType < 0 || q.messages.Empty() { + return nil, linuxerr.ENOMSG + } + + msg := q.msgAtIndex(mType) + if msg == nil { + return nil, linuxerr.ENOMSG + } + return msg, nil +} + +// msgOfType returns the first message with the specified type, nil if no +// message is found. If except is true, the first message of a type not equal +// to mType will be returned. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgOfType(mType int64, except bool) *Message { + if except { + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type != mType { + return msg + } + } + return nil + } + + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type == mType { + return msg + } + } + return nil +} + +// msgOfTypeLessThan return the the first message with the lowest type less +// than or equal to mType, nil if no such message exists. 
+// +// Precondition: caller must hold q.mu. +func (q *Queue) msgOfTypeLessThan(mType int64) (m *Message) { + min := mType + for msg := q.messages.Front(); msg != nil; msg = msg.Next() { + if msg.Type <= mType && msg.Type < min { + m = msg + min = msg.Type + } + } + return m +} + +// msgAtIndex returns a pointer to a message at given index, nil if non exits. +// +// Precondition: caller must hold q.mu. +func (q *Queue) msgAtIndex(mType int64) *Message { + msg := q.messages.Front() + for ; mType != 0 && msg != nil; mType-- { + msg = msg.Next() + } + return msg +} + // Lock implements ipc.Mechanism.Lock. func (q *Queue) Lock() { q.mu.Lock() diff --git a/pkg/sentry/platform/kvm/bluepill_fault.go b/pkg/sentry/platform/kvm/bluepill_fault.go index 28a613a54..8fd8287b3 100644 --- a/pkg/sentry/platform/kvm/bluepill_fault.go +++ b/pkg/sentry/platform/kvm/bluepill_fault.go @@ -101,7 +101,7 @@ func handleBluepillFault(m *machine, physical uintptr, phyRegions []physicalRegi // Store the physical address in the slot. This is used to // avoid calls to handleBluepillFault in the future (see // machine.mapPhysical). - atomic.StoreUintptr(&m.usedSlots[slot], physical) + atomic.StoreUintptr(&m.usedSlots[slot], physicalStart) // Successfully added region; we can increment nextSlot and // allow another set to proceed here. atomic.StoreUint32(&m.nextSlot, slot+1) diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index f63ab6aba..0f0c1e73b 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build go1.12 && !go1.18 -// +build go1.12,!go1.18 +//go:build go1.12 +// +build go1.12 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. 
Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package kvm @@ -28,7 +30,7 @@ import ( ) //go:linkname throw runtime.throw -func throw(string) +func throw(s string) // vCPUPtr returns a CPU for the given address. // diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 1b5d5f66e..e7092a756 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -70,7 +70,7 @@ type machine struct { // tscControl checks whether cpu supports TSC scaling tscControl bool - // usedSlots is the set of used physical addresses (sorted). + // usedSlots is the set of used physical addresses (not sorted). usedSlots []uintptr // nextID is the next vCPU ID. @@ -296,13 +296,20 @@ func newMachine(vm int) (*machine, error) { return m, nil } -// hasSlot returns true iff the given address is mapped. +// hasSlot returns true if the given address is mapped. // // This must be done via a linear scan. // //go:nosplit func (m *machine) hasSlot(physical uintptr) bool { - for i := 0; i < len(m.usedSlots); i++ { + slotLen := int(atomic.LoadUint32(&m.nextSlot)) + // When slots are being updated, nextSlot is ^uint32(0). As this situation + // is less likely happen, we just set the slotLen to m.maxSlots, and scan + // the whole usedSlots array. + if slotLen == int(^uint32(0)) { + slotLen = m.maxSlots + } + for i := 0; i < slotLen; i++ { if p := atomic.LoadUintptr(&m.usedSlots[i]); p == physical { return true } diff --git a/pkg/sentry/platform/kvm/machine_unsafe.go b/pkg/sentry/platform/kvm/machine_unsafe.go index 35660e827..cc3a1253b 100644 --- a/pkg/sentry/platform/kvm/machine_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_unsafe.go @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-//go:build go1.12 && !go1.18 -// +build go1.12,!go1.18 +//go:build go1.12 +// +build go1.12 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package kvm diff --git a/pkg/sentry/platform/ptrace/subprocess_unsafe.go b/pkg/sentry/platform/ptrace/subprocess_unsafe.go index ffd4665f4..304722200 100644 --- a/pkg/sentry/platform/ptrace/subprocess_unsafe.go +++ b/pkg/sentry/platform/ptrace/subprocess_unsafe.go @@ -12,10 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build go1.12 && !go1.18 -// +build go1.12,!go1.18 +//go:build go1.12 +// +build go1.12 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package ptrace diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index 7a4e78a5f..61111ac6c 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -309,6 +309,11 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { return interfaces } +// RemoveInterface implements inet.Stack.RemoveInterface. +func (*Stack) RemoveInterface(int32) error { + return linuxerr.EACCES +} + // InterfaceAddrs implements inet.Stack.InterfaceAddrs. func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { addrs := make(map[int32][]inet.InterfaceAddr) @@ -319,12 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. 
-func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { +func (*Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { return linuxerr.EACCES } // RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. -func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { +func (*Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return linuxerr.EACCES } @@ -339,7 +344,7 @@ func (s *Stack) TCPReceiveBufferSize() (inet.TCPBufferSize, error) { } // SetTCPReceiveBufferSize implements inet.Stack.SetTCPReceiveBufferSize. -func (s *Stack) SetTCPReceiveBufferSize(size inet.TCPBufferSize) error { +func (*Stack) SetTCPReceiveBufferSize(inet.TCPBufferSize) error { return linuxerr.EACCES } @@ -349,7 +354,7 @@ func (s *Stack) TCPSendBufferSize() (inet.TCPBufferSize, error) { } // SetTCPSendBufferSize implements inet.Stack.SetTCPSendBufferSize. -func (s *Stack) SetTCPSendBufferSize(size inet.TCPBufferSize) error { +func (*Stack) SetTCPSendBufferSize(inet.TCPBufferSize) error { return linuxerr.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(bool) error { +func (*Stack) SetTCPSACKEnabled(bool) error { return linuxerr.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { +func (*Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return linuxerr.EACCES } @@ -470,19 +475,19 @@ func (s *Stack) RouteTable() []inet.Route { } // Resume implements inet.Stack.Resume. -func (s *Stack) Resume() {} +func (*Stack) Resume() {} // RegisteredEndpoints implements inet.Stack.RegisteredEndpoints. 
-func (s *Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } +func (*Stack) RegisteredEndpoints() []stack.TransportEndpoint { return nil } // CleanupEndpoints implements inet.Stack.CleanupEndpoints. -func (s *Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } +func (*Stack) CleanupEndpoints() []stack.TransportEndpoint { return nil } // RestoreCleanupEndpoints implements inet.Stack.RestoreCleanupEndpoints. -func (s *Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} +func (*Stack) RestoreCleanupEndpoints([]stack.TransportEndpoint) {} // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { +func (*Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return linuxerr.EACCES } @@ -493,6 +498,6 @@ func (*Stack) PortRange() (uint16, uint16) { } // SetPortRange implements inet.Stack.SetPortRange. -func (*Stack) SetPortRange(start uint16, end uint16) error { +func (*Stack) SetPortRange(uint16, uint16) error { return linuxerr.EACCES } diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index 86f6419dc..d526acb73 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -161,6 +161,47 @@ func (p *Protocol) getLink(ctx context.Context, msg *netlink.Message, ms *netlin return nil } +// delLink handles RTM_DELLINK requests. +func (p *Protocol) delLink(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifinfomsg linux.InterfaceInfoMessage + attrs, ok := msg.GetData(&ifinfomsg) + if !ok { + return syserr.ErrInvalidArgument + } + if ifinfomsg.Index == 0 { + // The index is unspecified, search by the interface name. 
+ ahdr, value, _, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + switch ahdr.Type { + case linux.IFLA_IFNAME: + if len(value) < 1 { + return syserr.ErrInvalidArgument + } + ifname := string(value[:len(value)-1]) + for idx, ifa := range stack.Interfaces() { + if ifname == ifa.Name { + ifinfomsg.Index = idx + break + } + } + default: + return syserr.ErrInvalidArgument + } + if ifinfomsg.Index == 0 { + return syserr.ErrNoDevice + } + } + return syserr.FromError(stack.RemoveInterface(ifinfomsg.Index)) +} + // addNewLinkMessage appends RTM_NEWLINK message for the given interface into // the message set. func addNewLinkMessage(ms *netlink.MessageSet, idx int32, i inet.Interface) { @@ -537,6 +578,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms switch hdr.Type { case linux.RTM_GETLINK: return p.getLink(ctx, msg, ms) + case linux.RTM_DELLINK: + return p.delLink(ctx, msg, ms) case linux.RTM_GETROUTE: return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 0fd0ad32c..208ab9909 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -71,6 +71,12 @@ func (s *Stack) Interfaces() map[int32]inet.Interface { return is } +// RemoveInterface implements inet.Stack.RemoveInterface. +func (s *Stack) RemoveInterface(idx int32) error { + nic := tcpip.NICID(idx) + return syserr.TranslateNetstackError(s.Stack.RemoveNIC(nic)).ToError() +} + // InterfaceAddrs implements inet.Stack.InterfaceAddrs. 
func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { nicAddrs := make(map[int32][]inet.InterfaceAddr) diff --git a/pkg/sentry/syscalls/linux/BUILD b/pkg/sentry/syscalls/linux/BUILD index ccccce6a9..b5a371d9a 100644 --- a/pkg/sentry/syscalls/linux/BUILD +++ b/pkg/sentry/syscalls/linux/BUILD @@ -86,6 +86,7 @@ go_library( "//pkg/sentry/kernel/eventfd", "//pkg/sentry/kernel/fasync", "//pkg/sentry/kernel/ipc", + "//pkg/sentry/kernel/msgqueue", "//pkg/sentry/kernel/pipe", "//pkg/sentry/kernel/sched", "//pkg/sentry/kernel/shm", diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 6f44d767b..1ead3c7e8 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -122,8 +122,8 @@ var AMD64 = &kernel.SyscallTable{ 66: syscalls.Supported("semctl", Semctl), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.Supported("msgget", Msgget), - 69: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) - 70: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 69: syscalls.Supported("msgsnd", Msgsnd), + 70: syscalls.Supported("msgrcv", Msgrcv), 71: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), 72: syscalls.PartiallySupported("fcntl", Fcntl, "Not all options are supported.", nil), 73: syscalls.PartiallySupported("flock", Flock, "Locks are held within the sandbox only.", nil), @@ -618,8 +618,8 @@ var ARM64 = &kernel.SyscallTable{ 185: syscalls.ErrorWithEvent("mq_getsetattr", syserror.ENOSYS, "", []string{"gvisor.dev/issue/136"}), // TODO(b/29354921) 186: syscalls.Supported("msgget", Msgget), 187: syscalls.PartiallySupported("msgctl", Msgctl, "Only supports IPC_RMID option.", []string{"gvisor.dev/issue/135"}), - 188: syscalls.ErrorWithEvent("msgrcv", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // 
TODO(b/29354921) - 189: syscalls.ErrorWithEvent("msgsnd", linuxerr.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) + 188: syscalls.Supported("msgrcv", Msgrcv), + 189: syscalls.Supported("msgsnd", Msgsnd), 190: syscalls.Supported("semget", Semget), 191: syscalls.Supported("semctl", Semctl), 192: syscalls.Supported("semtimedop", Semtimedop), diff --git a/pkg/sentry/syscalls/linux/sys_msgqueue.go b/pkg/sentry/syscalls/linux/sys_msgqueue.go index 3476e218d..5259ade90 100644 --- a/pkg/sentry/syscalls/linux/sys_msgqueue.go +++ b/pkg/sentry/syscalls/linux/sys_msgqueue.go @@ -17,10 +17,12 @@ package linux import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/errors/linuxerr" + "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/sentry/arch" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/ipc" + "gvisor.dev/gvisor/pkg/sentry/kernel/msgqueue" ) // Msgget implements msgget(2). @@ -41,6 +43,89 @@ func Msgget(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal return uintptr(queue.ID()), nil, nil } +// Msgsnd implements msgsnd(2). +func Msgsnd(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + id := ipc.ID(args[0].Int()) + msgAddr := args[1].Pointer() + size := args[2].Int64() + flag := args[3].Int() + + if size < 0 || size > linux.MSGMAX { + return 0, nil, linuxerr.EINVAL + } + + wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT + pid := int32(t.ThreadGroup().ID()) + + buf := linux.MsgBuf{ + Text: make([]byte, size), + } + if _, err := buf.CopyIn(t, msgAddr); err != nil { + return 0, nil, err + } + + queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id) + if err != nil { + return 0, nil, err + } + + msg := msgqueue.Message{ + Type: int64(buf.Type), + Text: buf.Text, + Size: uint64(size), + } + return 0, nil, queue.Send(t, msg, t, wait, pid) +} + +// Msgrcv implements msgrcv(2). 
+func Msgrcv(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { + id := ipc.ID(args[0].Int()) + msgAddr := args[1].Pointer() + size := args[2].Int64() + mType := args[3].Int64() + flag := args[4].Int() + + wait := flag&linux.IPC_NOWAIT != linux.IPC_NOWAIT + except := flag&linux.MSG_EXCEPT == linux.MSG_EXCEPT + truncate := flag&linux.MSG_NOERROR == linux.MSG_NOERROR + + msgCopy := flag&linux.MSG_COPY == linux.MSG_COPY + + msg, err := receive(t, id, mType, size, msgCopy, wait, truncate, except) + if err != nil { + return 0, nil, err + } + + buf := linux.MsgBuf{ + Type: primitive.Int64(msg.Type), + Text: msg.Text, + } + if _, err := buf.CopyOut(t, msgAddr); err != nil { + return 0, nil, err + } + return uintptr(msg.Size), nil, nil +} + +// receive returns a message from the queue with the given ID. If msgCopy is +// true, a message is copied from the queue without being removed. Otherwise, +// a message is removed from the queue and returned. +func receive(t *kernel.Task, id ipc.ID, mType int64, maxSize int64, msgCopy, wait, truncate, except bool) (*msgqueue.Message, error) { + pid := int32(t.ThreadGroup().ID()) + + queue, err := t.IPCNamespace().MsgqueueRegistry().FindByID(id) + if err != nil { + return nil, err + } + + if msgCopy { + if wait || except { + return nil, linuxerr.EINVAL + } + return queue.Copy(mType) + } + return queue.Receive(t, t, mType, maxSize, wait, truncate, except, pid) +} + // Msgctl implements msgctl(2). 
func Msgctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.SyscallControl, error) { id := ipc.ID(args[0].Int()) diff --git a/pkg/sentry/syscalls/linux/sys_prctl.go b/pkg/sentry/syscalls/linux/sys_prctl.go index a16b6b4d6..2ef1e6404 100644 --- a/pkg/sentry/syscalls/linux/sys_prctl.go +++ b/pkg/sentry/syscalls/linux/sys_prctl.go @@ -219,6 +219,21 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall } return 0, nil, t.DropBoundingCapability(cp) + case linux.PR_SET_CHILD_SUBREAPER: + // "If arg2 is nonzero, set the "child subreaper" attribute of + // the calling process; if arg2 is zero, unset the attribute." + // + // TODO(gvisor.dev/issues/2323): We only support setting, and + // only if the task is already TID 1 in the PID namespace, + // because it already acts as a subreaper in that case. + isPid1 := t.PIDNamespace().IDOfTask(t) == kernel.InitTID + if args[1].Int() != 0 && isPid1 { + return 0, nil, nil + } + + t.Kernel().EmitUnimplementedEvent(t) + return 0, nil, linuxerr.EINVAL + case linux.PR_GET_TIMING, linux.PR_SET_TIMING, linux.PR_GET_TSC, @@ -230,7 +245,6 @@ func Prctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscall linux.PR_MCE_KILL, linux.PR_MCE_KILL_GET, linux.PR_GET_TID_ADDRESS, - linux.PR_SET_CHILD_SUBREAPER, linux.PR_GET_CHILD_SUBREAPER, linux.PR_GET_THP_DISABLE, linux.PR_SET_THP_DISABLE, diff --git a/pkg/sync/goyield_unsafe.go b/pkg/sync/goyield_unsafe.go index 8639bb64e..757edbaba 100644 --- a/pkg/sync/goyield_unsafe.go +++ b/pkg/sync/goyield_unsafe.go @@ -3,10 +3,12 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -//go:build go1.14 && !go1.18 -// +build go1.14,!go1.18 +//go:build go1.14 +// +build go1.14 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. 
Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package sync diff --git a/pkg/sync/runtime_unsafe.go b/pkg/sync/runtime_unsafe.go index 1d9cf304e..49d4109a9 100644 --- a/pkg/sync/runtime_unsafe.go +++ b/pkg/sync/runtime_unsafe.go @@ -6,8 +6,13 @@ //go:build go1.13 && !go1.18 // +build go1.13,!go1.18 -// Check go:linkname function signatures, type definitions, and constants when -// updating Go version. +// //go:linkname directives type-checked by checklinkname. Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. + +// Check type definitions and constants when updating Go version. +// +// TODO(b/165820485): add these checks to checklinkname. package sync @@ -109,10 +114,10 @@ type maptype struct { // These functions are only used within the sync package. //go:linkname semacquire sync.runtime_Semacquire -func semacquire(s *uint32) +func semacquire(addr *uint32) //go:linkname semrelease sync.runtime_Semrelease -func semrelease(s *uint32, handoff bool, skipframes int) +func semrelease(addr *uint32, handoff bool, skipframes int) //go:linkname canSpin sync.runtime_canSpin func canSpin(i int) bool diff --git a/pkg/tcpip/link/fdbased/endpoint.go b/pkg/tcpip/link/fdbased/endpoint.go index e8e716db0..48356c343 100644 --- a/pkg/tcpip/link/fdbased/endpoint.go +++ b/pkg/tcpip/link/fdbased/endpoint.go @@ -56,6 +56,7 @@ import ( // linkDispatcher reads packets from the link FD and dispatches them to the // NetworkDispatcher. type linkDispatcher interface { + stop() dispatch() (bool, tcpip.Error) } @@ -381,16 +382,27 @@ func isSocketFD(fd int) (bool, error) { // Attach launches the goroutine that reads packets from the file descriptor and // dispatches them via the provided dispatcher. 
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { - e.dispatcher = dispatcher - // Link endpoints are not savable. When transportation endpoints are - // saved, they stop sending outgoing packets and all incoming packets - // are rejected. - for i := range e.inboundDispatchers { - e.wg.Add(1) - go func(i int) { // S/R-SAFE: See above. - e.dispatchLoop(e.inboundDispatchers[i]) - e.wg.Done() - }(i) + // nil means the NIC is being removed. + if dispatcher == nil && e.dispatcher != nil { + for _, dispatcher := range e.inboundDispatchers { + dispatcher.stop() + } + e.Wait() + e.dispatcher = nil + return + } + if dispatcher != nil && e.dispatcher == nil { + e.dispatcher = dispatcher + // Link endpoints are not savable. When transportation endpoints are + // saved, they stop sending outgoing packets and all incoming packets + // are rejected. + for i := range e.inboundDispatchers { + e.wg.Add(1) + go func(i int) { // S/R-SAFE: See above. + e.dispatchLoop(e.inboundDispatchers[i]) + e.wg.Done() + }(i) + } } } diff --git a/pkg/tcpip/link/fdbased/mmap.go b/pkg/tcpip/link/fdbased/mmap.go index bfae34ab9..3f516cab5 100644 --- a/pkg/tcpip/link/fdbased/mmap.go +++ b/pkg/tcpip/link/fdbased/mmap.go @@ -114,6 +114,7 @@ func (t tPacketHdr) Payload() []byte { // packetMMapDispatcher uses PACKET_RX_RING's to read/dispatch inbound packets. // See: mmap_amd64_unsafe.go for implementation details. type packetMMapDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. 
fd int @@ -129,18 +130,18 @@ type packetMMapDispatcher struct { ringOffset int } -func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { +func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, bool, tcpip.Error) { hdr := tPacketHdr(d.ringBuffer[d.ringOffset*tpFrameSize:]) for hdr.tpStatus()&tpStatusUser == 0 { - event := rawfile.PollEvent{ - FD: int32(d.fd), - Events: unix.POLLIN | unix.POLLERR, - } - if _, errno := rawfile.BlockingPoll(&event, 1, nil); errno != 0 { + stopped, errno := rawfile.BlockingPollUntilStopped(d.efd, d.fd, unix.POLLIN|unix.POLLERR) + if errno != 0 { if errno == unix.EINTR { continue } - return nil, rawfile.TranslateErrno(errno) + return nil, stopped, rawfile.TranslateErrno(errno) + } + if stopped { + return nil, true, nil } if hdr.tpStatus()&tpStatusCopy != 0 { // This frame is truncated so skip it after flipping the @@ -158,14 +159,14 @@ func (d *packetMMapDispatcher) readMMappedPacket() ([]byte, tcpip.Error) { // Release packet to kernel. hdr.setTPStatus(tpStatusKernel) d.ringOffset = (d.ringOffset + 1) % tpFrameNR - return pkt, nil + return pkt, false, nil } // dispatch reads packets from an mmaped ring buffer and dispatches them to the // network stack. 
func (d *packetMMapDispatcher) dispatch() (bool, tcpip.Error) { - pkt, err := d.readMMappedPacket() - if err != nil { + pkt, stopped, err := d.readMMappedPacket() + if err != nil || stopped { return false, err } var ( diff --git a/pkg/tcpip/link/fdbased/mmap_unsafe.go b/pkg/tcpip/link/fdbased/mmap_unsafe.go index 58d5dfeef..5b786169a 100644 --- a/pkg/tcpip/link/fdbased/mmap_unsafe.go +++ b/pkg/tcpip/link/fdbased/mmap_unsafe.go @@ -47,9 +47,14 @@ func (t tPacketHdr) setTPStatus(status uint32) { } func newPacketMMapDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFd, err := newStopFd() + if err != nil { + return nil, err + } d := &packetMMapDispatcher{ - fd: fd, - e: e, + stopFd: stopFd, + fd: fd, + e: e, } pageSize := unix.Getpagesize() if tpBlockSize%pageSize != 0 { diff --git a/pkg/tcpip/link/fdbased/packet_dispatchers.go b/pkg/tcpip/link/fdbased/packet_dispatchers.go index ab2855a63..fab34c5fa 100644 --- a/pkg/tcpip/link/fdbased/packet_dispatchers.go +++ b/pkg/tcpip/link/fdbased/packet_dispatchers.go @@ -18,6 +18,8 @@ package fdbased import ( + "fmt" + "golang.org/x/sys/unix" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" @@ -114,9 +116,36 @@ func (b *iovecBuffer) pullViews(n int) buffer.VectorisedView { return buffer.NewVectorisedView(n, views) } +// stopFd is an eventfd used to signal the stop of a dispatcher. +type stopFd struct { + efd int +} + +func newStopFd() (stopFd, error) { + efd, err := unix.Eventfd(0, unix.EFD_NONBLOCK) + if err != nil { + return stopFd{efd: -1}, fmt.Errorf("failed to create eventfd: %w", err) + } + return stopFd{efd: efd}, nil +} + +// stop writes to the eventfd and notifies the dispatcher to stop. It does not +// block. +func (s *stopFd) stop() { + increment := []byte{1, 0, 0, 0, 0, 0, 0, 0} + if n, err := unix.Write(s.efd, increment); n != len(increment) || err != nil { + // There are two possible errors documented in eventfd(2) for writing: + // 1. 
We are writing 8 bytes and not 0xffffffffffffffff, thus no EINVAL. + // 2. stop is only supposed to be called once, it can't reach the limit, + // thus no EAGAIN. + panic(fmt.Sprintf("write(efd) = (%d, %s), want (%d, nil)", n, err, len(increment))) + } +} + // readVDispatcher uses readv() system call to read inbound packets and // dispatches them. type readVDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. fd int @@ -128,7 +157,15 @@ type readVDispatcher struct { } func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { - d := &readVDispatcher{fd: fd, e: e} + stopFd, err := newStopFd() + if err != nil { + return nil, err + } + d := &readVDispatcher{ + stopFd: stopFd, + fd: fd, + e: e, + } skipsVnetHdr := d.e.gsoKind == stack.HWGSOSupported d.buf = newIovecBuffer(BufConfig, skipsVnetHdr) return d, nil @@ -136,8 +173,8 @@ func newReadVDispatcher(fd int, e *endpoint) (linkDispatcher, error) { // dispatch reads one packet from the file descriptor and dispatches it. func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { - n, err := rawfile.BlockingReadv(d.fd, d.buf.nextIovecs()) - if n == 0 || err != nil { + n, err := rawfile.BlockingReadvUntilStopped(d.efd, d.fd, d.buf.nextIovecs()) + if n <= 0 || err != nil { return false, err } @@ -184,6 +221,7 @@ func (d *readVDispatcher) dispatch() (bool, tcpip.Error) { // recvMMsgDispatcher uses the recvmmsg system call to read inbound packets and // dispatches them. type recvMMsgDispatcher struct { + stopFd // fd is the file descriptor used to send and receive packets. 
fd int @@ -207,7 +245,12 @@ const ( ) func newRecvMMsgDispatcher(fd int, e *endpoint) (linkDispatcher, error) { + stopFd, err := newStopFd() + if err != nil { + return nil, err + } d := &recvMMsgDispatcher{ + stopFd: stopFd, fd: fd, e: e, bufs: make([]*iovecBuffer, MaxMsgsPerRecv), @@ -235,8 +278,8 @@ func (d *recvMMsgDispatcher) dispatch() (bool, tcpip.Error) { d.msgHdrs[k].Msg.SetIovlen(iovLen) } - nMsgs, err := rawfile.BlockingRecvMMsg(d.fd, d.msgHdrs) - if err != nil { + nMsgs, err := rawfile.BlockingRecvMMsgUntilStopped(d.efd, d.fd, d.msgHdrs) + if nMsgs == -1 || err != nil { return false, err } // Process each of received packets. diff --git a/pkg/tcpip/link/qdisc/fifo/endpoint.go b/pkg/tcpip/link/qdisc/fifo/endpoint.go index b1a28491d..40bd5560b 100644 --- a/pkg/tcpip/link/qdisc/fifo/endpoint.go +++ b/pkg/tcpip/link/qdisc/fifo/endpoint.go @@ -115,6 +115,13 @@ func (e *endpoint) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protoc // Attach implements stack.LinkEndpoint.Attach. func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) { + // nil means the NIC is being removed. + if dispatcher == nil { + e.lower.Attach(nil) + e.Wait() + e.dispatcher = nil + return + } e.dispatcher = dispatcher e.lower.Attach(e) } diff --git a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go index da900c24b..0b7b9e3de 100644 --- a/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go +++ b/pkg/tcpip/link/rawfile/blockingpoll_yield_unsafe.go @@ -12,12 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. -//go:build ((linux && amd64) || (linux && arm64)) && go1.12 && !go1.18 +//go:build ((linux && amd64) || (linux && arm64)) && go1.12 // +build linux,amd64 linux,arm64 // +build go1.12 -// +build !go1.18 -// Check go:linkname function signatures when updating Go version. +// //go:linkname directives type-checked by checklinkname. 
Any other +// non-linkname assumptions outside the Go 1 compatibility guarantee should +// have an accompanied vet check or version guard build tag. package rawfile diff --git a/pkg/tcpip/link/rawfile/rawfile_unsafe.go b/pkg/tcpip/link/rawfile/rawfile_unsafe.go index 53448a641..e76fc55b6 100644 --- a/pkg/tcpip/link/rawfile/rawfile_unsafe.go +++ b/pkg/tcpip/link/rawfile/rawfile_unsafe.go @@ -170,46 +170,63 @@ func BlockingRead(fd int, b []byte) (int, tcpip.Error) { } } -// BlockingReadv reads from a file descriptor that is set up as non-blocking and -// stores the data in a list of iovecs buffers. If no data is available, it will -// block in a poll() syscall until the file descriptor becomes readable. -func BlockingReadv(fd int, iovecs []unix.Iovec) (int, tcpip.Error) { +// BlockingReadvUntilStopped reads from a file descriptor that is set up as +// non-blocking and stores the data in a list of iovecs buffers. If no data is +// available, it will block in a poll() syscall until the file descriptor +// becomes readable or stop is signalled (efd becomes readable). Returns -1 in +// the latter case. +func BlockingReadvUntilStopped(efd int, fd int, iovecs []unix.Iovec) (int, tcpip.Error) { for { n, _, e := unix.RawSyscall(unix.SYS_READV, uintptr(fd), uintptr(unsafe.Pointer(&iovecs[0])), uintptr(len(iovecs))) if e == 0 { return int(n), nil } - event := PollEvent{ - FD: int32(fd), - Events: 1, // POLLIN + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, nil } - - _, e = BlockingPoll(&event, 1, nil) if e != 0 && e != unix.EINTR { return 0, TranslateErrno(e) } } } -// BlockingRecvMMsg reads from a file descriptor that is set up as non-blocking -// and stores the received messages in a slice of MMsgHdr structures. If no data -// is available, it will block in a poll() syscall until the file descriptor -// becomes readable. 
-func BlockingRecvMMsg(fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { +// BlockingRecvMMsgUntilStopped reads from a file descriptor that is set up as +// non-blocking and stores the received messages in a slice of MMsgHdr +// structures. If no data is available, it will block in a poll() syscall until +// the file descriptor becomes readable or stop is signalled (efd becomes +// readable). Returns -1 in the latter case. +func BlockingRecvMMsgUntilStopped(efd int, fd int, msgHdrs []MMsgHdr) (int, tcpip.Error) { for { n, _, e := unix.RawSyscall6(unix.SYS_RECVMMSG, uintptr(fd), uintptr(unsafe.Pointer(&msgHdrs[0])), uintptr(len(msgHdrs)), unix.MSG_DONTWAIT, 0, 0) if e == 0 { return int(n), nil } - event := PollEvent{ - FD: int32(fd), - Events: 1, // POLLIN + stopped, e := BlockingPollUntilStopped(efd, fd, unix.POLLIN) + if stopped { + return -1, nil } - - if _, e := BlockingPoll(&event, 1, nil); e != 0 && e != unix.EINTR { + if e != 0 && e != unix.EINTR { return 0, TranslateErrno(e) } } } + +// BlockingPollUntilStopped polls for events on fd or until a stop is signalled +// on the event fd efd. Returns true if stopped, i.e., efd has event POLLIN. +func BlockingPollUntilStopped(efd int, fd int, events int16) (bool, unix.Errno) { + pevents := [...]PollEvent{ + { + FD: int32(efd), + Events: unix.POLLIN, + }, + { + FD: int32(fd), + Events: events, + }, + } + _, errno := BlockingPoll(&pevents[0], len(pevents), nil) + return pevents[0].Revents&unix.POLLIN != 0, errno +} diff --git a/pkg/tcpip/link/sniffer/pcap.go b/pkg/tcpip/link/sniffer/pcap.go index 3bb864ed2..d3edede63 100644 --- a/pkg/tcpip/link/sniffer/pcap.go +++ b/pkg/tcpip/link/sniffer/pcap.go @@ -14,7 +14,14 @@ package sniffer -import "time" +import ( + "encoding" + "encoding/binary" + "time" + + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/stack" +) type pcapHeader struct { // MagicNumber is the file magic number. 
@@ -39,25 +46,38 @@ type pcapHeader struct { Network uint32 } -type pcapPacketHeader struct { - // Seconds is the timestamp seconds. - Seconds uint32 - - // Microseconds is the timestamp microseconds. - Microseconds uint32 +var _ encoding.BinaryMarshaler = (*pcapPacket)(nil) - // IncludedLength is the number of octets of packet saved in file. - IncludedLength uint32 - - // OriginalLength is the actual length of packet. - OriginalLength uint32 +type pcapPacket struct { + timestamp time.Time + packet *stack.PacketBuffer + maxCaptureLen int } -func newPCAPPacketHeader(now time.Time, incLen, orgLen uint32) pcapPacketHeader { - return pcapPacketHeader{ - Seconds: uint32(now.Unix()), - Microseconds: uint32(now.Nanosecond() / 1000), - IncludedLength: incLen, - OriginalLength: orgLen, +func (p *pcapPacket) MarshalBinary() ([]byte, error) { + packetSize := p.packet.Size() + captureLen := p.maxCaptureLen + if packetSize < captureLen { + captureLen = packetSize + } + b := make([]byte, 16+captureLen) + binary.BigEndian.PutUint32(b[0:4], uint32(p.timestamp.Unix())) + binary.BigEndian.PutUint32(b[4:8], uint32(p.timestamp.Nanosecond()/1000)) + binary.BigEndian.PutUint32(b[8:12], uint32(captureLen)) + binary.BigEndian.PutUint32(b[12:16], uint32(packetSize)) + w := tcpip.SliceWriter(b[16:]) + for _, v := range p.packet.Views() { + if captureLen == 0 { + break + } + if len(v) > captureLen { + v = v[:captureLen] + } + n, err := w.Write(v) + if err != nil { + panic(err) + } + captureLen -= n } + return b, nil } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 3df826f3c..28a172e71 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -151,33 +151,16 @@ func (e *endpoint) dumpPacket(dir direction, protocol tcpip.NetworkProtocolNumbe logPacket(e.logPrefix, dir, protocol, pkt) } if writer != nil && atomic.LoadUint32(&LogPacketsToPCAP) == 1 { - totalLength := pkt.Size() - length := totalLength - if max := 
int(e.maxPCAPLen); length > max { - length = max + packet := pcapPacket{ + timestamp: time.Now(), + packet: pkt, + maxCaptureLen: int(e.maxPCAPLen), } - packetHeader := newPCAPPacketHeader(time.Now(), uint32(length), uint32(totalLength)) - packet := make([]byte, binary.Size(packetHeader)+length) - { - writer := tcpip.SliceWriter(packet) - if err := binary.Write(&writer, binary.BigEndian, packetHeader); err != nil { - panic(err) - } - for _, b := range pkt.Views() { - if length == 0 { - break - } - if len(b) > length { - b = b[:length] - } - n, err := writer.Write(b) - if err != nil { - panic(err) - } - length -= n - } + b, err := packet.MarshalBinary() + if err != nil { + panic(err) } - if _, err := writer.Write(packet); err != nil { + if _, err := writer.Write(b); err != nil { panic(err) } } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index c90974693..2257f728e 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -39,7 +39,6 @@ go_test( "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", - "//pkg/tcpip/link/loopback", "//pkg/tcpip/link/sniffer", "//pkg/tcpip/network/arp", "//pkg/tcpip/network/internal/testutil", diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index 4a4448cf9..73407be67 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -32,7 +32,6 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" - "gvisor.dev/gvisor/pkg/tcpip/link/loopback" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" "gvisor.dev/gvisor/pkg/tcpip/network/arp" iptestutil "gvisor.dev/gvisor/pkg/tcpip/network/internal/testutil" @@ -3339,7 +3338,7 @@ func TestCloseLocking(t *testing.T) { defer wg.Done() for i := 0; i < iterations; i++ { - if err := s.CreateNIC(nicID2, loopback.New()); err != nil { + if err := s.CreateNIC(nicID2, stack.LinkEndpoint(channel.New(0, 
defaultMTU, ""))); err != nil { t.Errorf("CreateNIC(%d, _): %s", nicID2, err) return } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 81fabe29a..c73890c4c 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -780,6 +780,9 @@ func (s *Stack) removeNICLocked(id tcpip.NICID) tcpip.Error { if !ok { return &tcpip.ErrUnknownNICID{} } + if nic.IsLoopback() { + return &tcpip.ErrNotSupported{} + } delete(s.nics, id) // Remove routes in-place. n tracks the number of routes written. diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index 21951d05a..3089c0ef4 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -719,38 +719,59 @@ func TestRemoveUnknownNIC(t *testing.T) { } func TestRemoveNIC(t *testing.T) { - const nicID = 1 + for _, tt := range []struct { + name string + linkep stack.LinkEndpoint + expectErr tcpip.Error + }{ + { + name: "loopback", + linkep: loopback.New(), + expectErr: &tcpip.ErrNotSupported{}, + }, + { + name: "channel", + linkep: channel.New(0, defaultMTU, ""), + expectErr: nil, + }, + } { + t.Run(tt.name, func(t *testing.T) { + const nicID = 1 - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, - }) + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + }) - e := linkEPWithMockedAttach{ - LinkEndpoint: loopback.New(), - } - if err := s.CreateNIC(nicID, &e); err != nil { - t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) - } + e := linkEPWithMockedAttach{ + LinkEndpoint: tt.linkep, + } + if err := s.CreateNIC(nicID, &e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } - // NIC should be present in NICInfo and attached to a NetworkDispatcher. 
- allNICInfo := s.NICInfo() - if _, ok := allNICInfo[nicID]; !ok { - t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) - } - if !e.isAttached() { - t.Fatal("link endpoint not attached to a network dispatcher") - } + // NIC should be present in NICInfo and attached to a NetworkDispatcher. + allNICInfo := s.NICInfo() + if _, ok := allNICInfo[nicID]; !ok { + t.Errorf("entry for %d missing from allNICInfo = %+v", nicID, allNICInfo) + } + if !e.isAttached() { + t.Fatal("link endpoint not attached to a network dispatcher") + } - // Removing a NIC should remove it from NICInfo and e should be detached from - // the NetworkDispatcher. - if err := s.RemoveNIC(nicID); err != nil { - t.Fatalf("s.RemoveNIC(%d): %s", nicID, err) - } - if nicInfo, ok := s.NICInfo()[nicID]; ok { - t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) - } - if e.isAttached() { - t.Error("link endpoint for removed NIC still attached to a network dispatcher") + // Removing a NIC should remove it from NICInfo and e should be detached from + // the NetworkDispatcher. + if got, want := s.RemoveNIC(nicID), tt.expectErr; got != want { + t.Fatalf("got s.RemoveNIC(%d) = %s, want %s", nicID, got, want) + } + if tt.expectErr == nil { + if nicInfo, ok := s.NICInfo()[nicID]; ok { + t.Errorf("got unexpected NICInfo entry for deleted NIC %d = %+v", nicID, nicInfo) + } + if e.isAttached() { + t.Error("link endpoint for removed NIC still attached to a network dispatcher") + } + } + }) } } diff --git a/runsc/boot/controller.go b/runsc/boot/controller.go index e5b0ec3ae..60b532798 100644 --- a/runsc/boot/controller.go +++ b/runsc/boot/controller.go @@ -57,20 +57,12 @@ const ( // ContMgrExecuteAsync executes a command in a container. ContMgrExecuteAsync = "containerManager.ExecuteAsync" - // ContMgrPause pauses the sandbox (note that individual containers cannot be - // paused). 
- ContMgrPause = "containerManager.Pause" - // ContMgrProcesses lists processes running in a container. ContMgrProcesses = "containerManager.Processes" // ContMgrRestore restores a container from a statefile. ContMgrRestore = "containerManager.Restore" - // ContMgrResume unpauses the paused sandbox (note that individual containers - // cannot be resumed). - ContMgrResume = "containerManager.Resume" - // ContMgrSignal sends a signal to a container. ContMgrSignal = "containerManager.Signal" @@ -111,6 +103,17 @@ const ( LoggingChange = "Logging.Change" ) +// Lifecycle related commands (see lifecycle.go for more details). +const ( + LifecyclePause = "Lifecycle.Pause" + LifecycleResume = "Lifecycle.Resume" +) + +// Filesystem related commands (see fs.go for more details). +const ( + FsCat = "Fs.Cat" +) + // ControlSocketAddr generates an abstract unix socket name for the given ID. func ControlSocketAddr(id string) string { return fmt.Sprintf("\x00runsc-sandbox.%s", id) @@ -152,6 +155,8 @@ func newController(fd int, l *Loader) (*controller, error) { ctrl.srv.Register(&debug{}) ctrl.srv.Register(&control.Logging{}) + ctrl.srv.Register(&control.Lifecycle{l.k}) + ctrl.srv.Register(&control.Fs{l.k}) if l.root.conf.ProfileEnable { ctrl.srv.Register(control.NewProfile(l.k)) @@ -340,17 +345,6 @@ func (cm *containerManager) Checkpoint(o *control.SaveOpts, _ *struct{}) error { return state.Save(o, nil) } -// Pause suspends a sandbox. -func (cm *containerManager) Pause(_, _ *struct{}) error { - log.Debugf("containerManager.Pause") - // TODO(gvisor.dev/issues/6243): save/restore not supported w/ hostinet - if cm.l.root.conf.Network == config.NetworkHost { - return errors.New("pause not supported when using hostinet") - } - cm.l.k.Pause() - return nil -} - // RestoreOpts contains options related to restoring a container's file system. 
type RestoreOpts struct { // FilePayload contains the state file to be restored, followed by the @@ -482,13 +476,6 @@ func (cm *containerManager) Restore(o *RestoreOpts, _ *struct{}) error { return nil } -// Resume unpauses a sandbox. -func (cm *containerManager) Resume(_, _ *struct{}) error { - log.Debugf("containerManager.Resume") - cm.l.k.Unpause() - return nil -} - // Wait waits for the init process in the given container. func (cm *containerManager) Wait(cid *string, waitStatus *uint32) error { log.Debugf("containerManager.Wait, cid: %s", *cid) diff --git a/runsc/cmd/chroot.go b/runsc/cmd/chroot.go index 7b11b3367..1fe9c6435 100644 --- a/runsc/cmd/chroot.go +++ b/runsc/cmd/chroot.go @@ -59,6 +59,23 @@ func pivotRoot(root string) error { return nil } +func copyFile(dst, src string) error { + in, err := os.Open(src) + if err != nil { + return err + } + defer in.Close() + + out, err := os.Create(dst) + if err != nil { + return err + } + defer out.Close() + + _, err = out.ReadFrom(in) + return err +} + // setUpChroot creates an empty directory with runsc mounted at /runsc and proc // mounted at /proc. func setUpChroot(pidns bool) error { @@ -78,6 +95,14 @@ func setUpChroot(pidns bool) error { return fmt.Errorf("error mounting tmpfs in choot: %v", err) } + if err := os.Mkdir(filepath.Join(chroot, "etc"), 0755); err != nil { + return fmt.Errorf("error creating /etc in chroot: %v", err) + } + + if err := copyFile(filepath.Join(chroot, "etc/localtime"), "/etc/localtime"); err != nil { + log.Warningf("Failed to copy /etc/localtime: %v. 
UTC timezone will be used.", err) + } + if pidns { flags := uint32(unix.MS_NOSUID | unix.MS_NODEV | unix.MS_NOEXEC | unix.MS_RDONLY) if err := mountInChroot(chroot, "proc", "/proc", "proc", flags); err != nil { diff --git a/runsc/cmd/debug.go b/runsc/cmd/debug.go index da81cf048..f773ccca0 100644 --- a/runsc/cmd/debug.go +++ b/runsc/cmd/debug.go @@ -48,6 +48,7 @@ type Debug struct { delay time.Duration duration time.Duration ps bool + cat stringSlice } // Name implements subcommands.Command. @@ -81,6 +82,7 @@ func (d *Debug) SetFlags(f *flag.FlagSet) { f.StringVar(&d.logLevel, "log-level", "", "The log level to set: warning (0), info (1), or debug (2).") f.StringVar(&d.logPackets, "log-packets", "", "A boolean value to enable or disable packet logging: true or false.") f.BoolVar(&d.ps, "ps", false, "lists processes") + f.Var(&d.cat, "cat", "reads files and print to standard output") } // Execute implements subcommands.Command.Execute. @@ -367,5 +369,11 @@ func (d *Debug) Execute(_ context.Context, f *flag.FlagSet, args ...interface{}) return subcommands.ExitFailure } + if d.cat != nil { + if err := c.Cat(d.cat, os.Stdout); err != nil { + return Errorf("Cat failed: %v", err) + } + } + return subcommands.ExitSuccess } diff --git a/runsc/cmd/gofer.go b/runsc/cmd/gofer.go index 20e05f141..2193e9040 100644 --- a/runsc/cmd/gofer.go +++ b/runsc/cmd/gofer.go @@ -285,16 +285,22 @@ func setupRootFS(spec *specs.Spec, conf *config.Config) error { // Prepare tree structure for pivot_root(2). if err := os.Mkdir("/proc/proc", 0755); err != nil { - Fatalf("%v", err) + Fatalf("error creating /proc/proc: %v", err) } if err := os.Mkdir("/proc/root", 0755); err != nil { - Fatalf("%v", err) + Fatalf("error creating /proc/root: %v", err) + } + if err := os.Mkdir("/proc/etc", 0755); err != nil { + Fatalf("error creating /proc/etc: %v", err) } // This cannot use SafeMount because there's no available procfs. But we // know that /proc is an empty tmpfs mount, so this is safe. 
if err := unix.Mount("runsc-proc", "/proc/proc", "proc", flags|unix.MS_RDONLY, ""); err != nil { Fatalf("error mounting proc: %v", err) } + if err := copyFile("/proc/etc/localtime", "/etc/localtime"); err != nil { + log.Warningf("Failed to copy /etc/localtime: %v. UTC timezone will be used.", err) + } root = "/proc/root" procPath = "/proc/proc" } @@ -409,7 +415,7 @@ func resolveMounts(conf *config.Config, mounts []specs.Mount, root string) ([]sp panic(fmt.Sprintf("%q could not be made relative to %q: %v", dst, root, err)) } - opts, err := adjustMountOptions(filepath.Join(root, relDst), m.Options) + opts, err := adjustMountOptions(conf, filepath.Join(root, relDst), m.Options) if err != nil { return nil, err } @@ -475,7 +481,7 @@ func resolveSymlinksImpl(root, base, rel string, followCount uint) (string, erro } // adjustMountOptions adds 'overlayfs_stale_read' if mounting over overlayfs. -func adjustMountOptions(path string, opts []string) ([]string, error) { +func adjustMountOptions(conf *config.Config, path string, opts []string) ([]string, error) { rv := make([]string, len(opts)) copy(rv, opts) diff --git a/runsc/container/container.go b/runsc/container/container.go index 6a9a07afe..d1f979eb2 100644 --- a/runsc/container/container.go +++ b/runsc/container/container.go @@ -646,6 +646,12 @@ func (c *Container) Resume() error { return c.saveLocked() } +// Cat prints out the content of the files. +func (c *Container) Cat(files []string, out *os.File) error { + log.Debugf("Cat in container, cid: %s, files: %+v", c.ID, files) + return c.Sandbox.Cat(c.ID, files, out) +} + // State returns the metadata of the container. 
func (c *Container) State() specs.State { return specs.State{ diff --git a/runsc/container/container_test.go b/runsc/container/container_test.go index 5fb4a3672..960c36946 100644 --- a/runsc/container/container_test.go +++ b/runsc/container/container_test.go @@ -442,6 +442,11 @@ func configs(t *testing.T, opts ...configOption) map[string]*config.Config { return all } +// sleepSpec generates a spec with sleep 1000 and a conf. +func sleepSpecConf(t *testing.T) (*specs.Spec, *config.Config) { + return testutil.NewSpecWithArgs("sleep", "1000"), testutil.TestConfig(t) +} + // TestLifecycle tests the basic Create/Start/Signal/Destroy container lifecycle. // It verifies after each step that the container can be loaded from disk, and // has the correct status. @@ -455,7 +460,7 @@ func TestLifecycle(t *testing.T) { t.Run(name, func(t *testing.T) { // The container will just sleep for a long time. We will kill it before // it finishes sleeping. - spec := testutil.NewSpecWithArgs("sleep", "100") + spec, _ := sleepSpecConf(t) rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { @@ -903,7 +908,7 @@ func TestExecProcList(t *testing.T) { for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { const uid = 343 - spec := testutil.NewSpecWithArgs("sleep", "100") + spec, _ := sleepSpecConf(t) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { @@ -1422,8 +1427,7 @@ func TestPauseResume(t *testing.T) { // with calls to pause and resume and that pausing and resuming only // occurs given the correct state. func TestPauseResumeStatus(t *testing.T) { - spec := testutil.NewSpecWithArgs("sleep", "20") - conf := testutil.TestConfig(t) + spec, conf := sleepSpecConf(t) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1490,7 +1494,7 @@ func TestCapabilities(t *testing.T) { for name, conf := range configs(t, all...) 
{ t.Run(name, func(t *testing.T) { - spec := testutil.NewSpecWithArgs("sleep", "100") + spec, _ := sleepSpecConf(t) rootDir, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -1640,7 +1644,7 @@ func TestMountNewDir(t *testing.T) { func TestReadonlyRoot(t *testing.T) { for name, conf := range configs(t, all...) { t.Run(name, func(t *testing.T) { - spec := testutil.NewSpecWithArgs("sleep", "100") + spec, _ := sleepSpecConf(t) spec.Root.Readonly = true _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) @@ -1692,7 +1696,7 @@ func TestReadonlyMount(t *testing.T) { if err != nil { t.Fatalf("ioutil.TempDir() failed: %v", err) } - spec := testutil.NewSpecWithArgs("sleep", "100") + spec, _ := sleepSpecConf(t) spec.Mounts = append(spec.Mounts, specs.Mount{ Destination: dir, Source: dir, @@ -1852,7 +1856,7 @@ func doAbbreviatedIDsTest(t *testing.T, vfs2 bool) { "baz-" + testutil.RandomContainerID(), } for _, cid := range cids { - spec := testutil.NewSpecWithArgs("sleep", "100") + spec, _ := sleepSpecConf(t) bundleDir, cleanup, err := testutil.SetupBundleDir(spec) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -2229,7 +2233,7 @@ func TestMountPropagation(t *testing.T) { t.Fatalf("mount(%q, MS_SHARED): %v", srcMnt, err) } - spec := testutil.NewSpecWithArgs("sleep", "1000") + spec, conf := sleepSpecConf(t) priv := filepath.Join(tmpDir, "priv") slave := filepath.Join(tmpDir, "slave") @@ -2248,7 +2252,6 @@ func TestMountPropagation(t *testing.T) { }, } - conf := testutil.TestConfig(t) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -2563,12 +2566,11 @@ func TestRlimits(t *testing.T) { // TestRlimitsExec sets limit to number of open files and checks that the limit // is propagated to exec'd processes. 
func TestRlimitsExec(t *testing.T) { - spec := testutil.NewSpecWithArgs("sleep", "100") + spec, conf := sleepSpecConf(t) spec.Process.Rlimits = []specs.POSIXRlimit{ {Type: "RLIMIT_NOFILE", Hard: 1000, Soft: 100}, } - conf := testutil.TestConfig(t) _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) if err != nil { t.Fatalf("error setting up container: %v", err) @@ -2597,3 +2599,59 @@ func TestRlimitsExec(t *testing.T) { t.Errorf("ulimit result, got: %q, want: %q", got, want) } } + +// TestCat creates a file and checks that cat generates the expected output. +func TestCat(t *testing.T) { + f, err := ioutil.TempFile(testutil.TmpDir(), "test-case") + if err != nil { + t.Fatalf("ioutil.TempFile failed: %v", err) + } + defer os.RemoveAll(f.Name()) + + content := "test-cat" + if _, err := f.WriteString(content); err != nil { + t.Fatalf("f.WriteString(): %v", err) + } + f.Close() + + spec, conf := sleepSpecConf(t) + + _, bundleDir, cleanup, err := testutil.SetupContainer(spec, conf) + if err != nil { + t.Fatalf("error setting up container: %v", err) + } + defer cleanup() + + args := Args{ + ID: testutil.RandomContainerID(), + Spec: spec, + BundleDir: bundleDir, + } + + cont, err := New(conf, args) + if err != nil { + t.Fatalf("Creating container: %v", err) + } + defer cont.Destroy() + + if err := cont.Start(conf); err != nil { + t.Fatalf("starting container: %v", err) + } + + r, w, err := os.Pipe() + if err != nil { + t.Fatalf("os.Create(): %v", err) + } + + if err := cont.Cat([]string{f.Name()}, w); err != nil { + t.Fatalf("error cat from container: %v", err) + } + + buf := make([]byte, 1024) + if _, err := r.Read(buf); err != nil { + t.Fatalf("Read out: %v", err) + } + if got, want := string(buf), content; !strings.Contains(got, want) { + t.Errorf("out got %s, want include %s", buf, want) + } +} diff --git a/runsc/sandbox/sandbox.go b/runsc/sandbox/sandbox.go index 5fb7dc834..b15572a98 100644 --- a/runsc/sandbox/sandbox.go +++ b/runsc/sandbox/sandbox.go @@ 
-981,7 +981,7 @@ func (s *Sandbox) Pause(cid string) error { } defer conn.Close() - if err := conn.Call(boot.ContMgrPause, nil, nil); err != nil { + if err := conn.Call(boot.LifecyclePause, nil, nil); err != nil { return fmt.Errorf("pausing container %q: %v", cid, err) } return nil @@ -996,12 +996,30 @@ func (s *Sandbox) Resume(cid string) error { } defer conn.Close() - if err := conn.Call(boot.ContMgrResume, nil, nil); err != nil { + if err := conn.Call(boot.LifecycleResume, nil, nil); err != nil { return fmt.Errorf("resuming container %q: %v", cid, err) } return nil } +// Cat sends the cat call for a container in the sandbox. +func (s *Sandbox) Cat(cid string, files []string, out *os.File) error { + log.Debugf("Cat sandbox %q", s.ID) + conn, err := s.sandboxConnect() + if err != nil { + return err + } + defer conn.Close() + + if err := conn.Call(boot.FsCat, &control.CatOpts{ + Files: files, + FilePayload: urpc.FilePayload{Files: []*os.File{out}}, + }, nil); err != nil { + return fmt.Errorf("Cat container %q: %v", cid, err) + } + return nil +} + // IsRunning returns true if the sandbox or gofer process is running. func (s *Sandbox) IsRunning() bool { if s.Pid != 0 { diff --git a/test/e2e/integration_test.go b/test/e2e/integration_test.go index 9e22c9a7d..d41139944 100644 --- a/test/e2e/integration_test.go +++ b/test/e2e/integration_test.go @@ -742,3 +742,49 @@ func TestUnmount(t *testing.T) { t.Fatalf("docker run failed: %v", err) } } + +func TestDeleteInterface(t *testing.T) { + if testutil.IsRunningWithHostNet() { + t.Skip("not able to remove interfaces on hostnet") + } + + ctx := context.Background() + d := dockerutil.MakeContainer(ctx, t) + defer d.CleanUp(ctx) + + opts := dockerutil.RunOpts{ + Image: "basic/alpine", + CapAdd: []string{"NET_ADMIN"}, + } + if err := d.Spawn(ctx, opts, "sleep", "1000"); err != nil { + t.Fatalf("docker run failed: %v", err) + } + + // We should be able to remove eth0. 
+ output, err := d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link del dev eth0") + if err != nil { + t.Fatalf("failed to remove eth0: %s, output: %s", err, output) + } + // Verify that eth0 is no longer there. + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link show") + if err != nil { + t.Fatalf("docker exec ip link show failed: %s, output: %s", err, output) + } + if strings.Contains(output, "eth0") { + t.Fatalf("failed to remove eth0") + } + + // Loopback device can't be removed. + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link del dev lo") + if err == nil { + t.Fatalf("should not remove the loopback device: %v", output) + } + // Verify that lo is still there. + output, err = d.Exec(ctx, dockerutil.ExecOpts{}, "/bin/sh", "-c", "ip link show") + if err != nil { + t.Fatalf("docker exec ip link show failed: %s, output: %s", err, output) + } + if !strings.Contains(output, "lo") { + t.Fatalf("loopback interface is removed") + } +} diff --git a/test/perf/BUILD b/test/perf/BUILD index 97ca0e75a..58fe333ee 100644 --- a/test/perf/BUILD +++ b/test/perf/BUILD @@ -70,6 +70,13 @@ syscall_test( ) syscall_test( + size = "large", + add_overlay = True, + debug = False, + test = "//test/perf/linux:dup_benchmark", +) + +syscall_test( debug = False, test = "//test/perf/linux:pipe_benchmark", ) @@ -146,3 +153,24 @@ syscall_test( test = "//test/perf/linux:verity_open_benchmark", vfs1 = False, ) + +syscall_test( + size = "large", + debug = False, + test = "//test/perf/linux:verity_read_benchmark", + vfs1 = False, +) + +syscall_test( + size = "large", + debug = False, + test = "//test/perf/linux:verity_randread_benchmark", + vfs1 = False, +) + +syscall_test( + size = "large", + debug = False, + test = "//test/perf/linux:verity_open_read_close_benchmark", + vfs1 = False, +) diff --git a/test/perf/linux/BUILD b/test/perf/linux/BUILD index b4f192227..61ed98ff5 100644 --- a/test/perf/linux/BUILD +++ b/test/perf/linux/BUILD @@ 
-109,6 +109,22 @@ cc_binary( ) cc_binary( + name = "dup_benchmark", + testonly = 1, + srcs = [ + "dup_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + ], +) + +cc_binary( name = "read_benchmark", testonly = 1, srcs = [ @@ -389,3 +405,60 @@ cc_binary( "//test/util:verity_util", ], ) + +cc_binary( + name = "verity_read_benchmark", + testonly = 1, + srcs = [ + "verity_read_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:capability_util", + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:verity_util", + ], +) + +cc_binary( + name = "verity_randread_benchmark", + testonly = 1, + srcs = [ + "verity_randread_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:capability_util", + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:verity_util", + ], +) + +cc_binary( + name = "verity_open_read_close_benchmark", + testonly = 1, + srcs = [ + "verity_open_read_close_benchmark.cc", + ], + deps = [ + gbenchmark, + gtest, + "//test/util:capability_util", + "//test/util:fs_util", + "//test/util:logging", + "//test/util:temp_path", + "//test/util:test_main", + "//test/util:test_util", + "//test/util:verity_util", + ], +) diff --git a/test/perf/linux/dup_benchmark.cc b/test/perf/linux/dup_benchmark.cc new file mode 100644 index 000000000..5d808d225 --- /dev/null +++ b/test/perf/linux/dup_benchmark.cc @@ -0,0 +1,55 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. 
+// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" + +namespace gvisor { + +namespace { + +void BM_Dup(benchmark::State& state) { + const int size = state.range(0); + + for (auto _ : state) { + std::vector<int> v; + for (int i = 0; i < size; i++) { + int fd = dup(2); + TEST_CHECK(fd != -1); + v.push_back(fd); + } + for (int i = 0; i < size; i++) { + int fd = v[i]; + close(fd); + } + } + state.SetItemsProcessed(state.iterations() * size); +} + +BENCHMARK(BM_Dup)->Range(1, 1 << 15)->UseRealTime(); + +} // namespace + +} // namespace gvisor diff --git a/test/perf/linux/verity_open_benchmark.cc b/test/perf/linux/verity_open_benchmark.cc index ce53a2100..026b6f101 100644 --- a/test/perf/linux/verity_open_benchmark.cc +++ b/test/perf/linux/verity_open_benchmark.cc @@ -36,8 +36,6 @@ namespace testing { namespace { void BM_Open(benchmark::State& state) { - SKIP_IF(IsRunningWithVFS1()); - const int size = state.range(0); std::vector<TempPath> cache; std::vector<EnableTarget> targets; diff --git a/test/perf/linux/verity_open_read_close_benchmark.cc b/test/perf/linux/verity_open_read_close_benchmark.cc new file mode 100644 index 000000000..e77577f22 --- /dev/null +++ b/test/perf/linux/verity_open_read_close_benchmark.cc @@ -0,0 +1,75 @@ +// Copyright 2021 The gVisor Authors. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include <fcntl.h> +#include <stdlib.h> +#include <sys/mount.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/capability_util.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" +#include "test/util/verity_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_VerityOpenReadClose(benchmark::State& state) { + const int size = state.range(0); + + // Mount a tmpfs file system to be wrapped by a verity fs. 
+ TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK(mount("", dir.path().c_str(), "tmpfs", 0, "") == 0); + + std::vector<TempPath> cache; + std::vector<EnableTarget> targets; + + for (int i = 0; i < size; i++) { + auto file = ASSERT_NO_ERRNO_AND_VALUE( + TempPath::CreateFileWith(dir.path(), "some contents", 0644)); + targets.emplace_back( + EnableTarget(std::string(Basename(file.path())), O_RDONLY)); + cache.emplace_back(std::move(file)); + } + + std::string verity_dir = + TEST_CHECK_NO_ERRNO_AND_VALUE(MountVerity(dir.path(), targets)); + + char buf[1]; + unsigned int seed = 1; + for (auto _ : state) { + const int chosen = rand_r(&seed) % size; + int fd = open(JoinPath(verity_dir, targets[chosen].path).c_str(), O_RDONLY); + TEST_CHECK(fd != -1); + TEST_CHECK(read(fd, buf, 1) == 1); + close(fd); + } +} + +BENCHMARK(BM_VerityOpenReadClose)->Range(1000, 16384)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/verity_randread_benchmark.cc b/test/perf/linux/verity_randread_benchmark.cc new file mode 100644 index 000000000..4178cfad8 --- /dev/null +++ b/test/perf/linux/verity_randread_benchmark.cc @@ -0,0 +1,108 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <fcntl.h> +#include <stdlib.h> +#include <sys/mount.h> +#include <sys/stat.h> +#include <sys/uio.h> +#include <unistd.h> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" +#include "test/util/verity_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +// Create a 1GB file that will be read from at random positions. This should +// invalid any performance gains from caching. +const uint64_t kFileSize = Megabytes(1024); + +// How many bytes to write at once to initialize the file used to read from. +const uint32_t kWriteSize = 65536; + +// Largest benchmarked read unit. +const uint32_t kMaxRead = Megabytes(64); + +// Global test state, initialized once per process lifetime. +struct GlobalState { + explicit GlobalState() { + // Mount a tmpfs file system to be wrapped by a verity fs. + tmp_dir_ = TempPath::CreateDir().ValueOrDie(); + TEST_CHECK(mount("", tmp_dir_.path().c_str(), "tmpfs", 0, "") == 0); + file_ = TempPath::CreateFileIn(tmp_dir_.path()).ValueOrDie(); + filename_ = std::string(Basename(file_.path())); + + FileDescriptor fd = Open(file_.path(), O_WRONLY).ValueOrDie(); + + // Try to minimize syscalls by using maximum size writev() requests. + std::vector<char> buffer(kWriteSize); + RandomizeBuffer(buffer.data(), buffer.size()); + const std::vector<std::vector<struct iovec>> iovecs_list = + GenerateIovecs(kFileSize + kMaxRead, buffer.data(), buffer.size()); + for (const auto& iovecs : iovecs_list) { + TEST_CHECK(writev(fd.get(), iovecs.data(), iovecs.size()) >= 0); + } + verity_dir_ = + MountVerity(tmp_dir_.path(), {EnableTarget(filename_, O_RDONLY)}) + .ValueOrDie(); + } + TempPath tmp_dir_; + TempPath file_; + std::string verity_dir_; + std::string filename_; +}; + +GlobalState& GetGlobalState() { + // This gets created only once throughout the lifetime of the process. 
+ // Use a dynamically allocated object (that is never deleted) to avoid order + // of destruction of static storage variables issues. + static GlobalState* const state = + // The actual file size is the maximum random seek range (kFileSize) + the + // maximum read size so we can read that number of bytes at the end of the + // file. + new GlobalState(); + return *state; +} + +void BM_VerityRandRead(benchmark::State& state) { + const int size = state.range(0); + + GlobalState& global_state = GetGlobalState(); + FileDescriptor verity_fd = ASSERT_NO_ERRNO_AND_VALUE(Open( + JoinPath(global_state.verity_dir_, global_state.filename_), O_RDONLY)); + std::vector<char> buf(size); + + unsigned int seed = 1; + for (auto _ : state) { + TEST_CHECK(PreadFd(verity_fd.get(), buf.data(), buf.size(), + rand_r(&seed) % kFileSize) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_VerityRandRead)->Range(1, kMaxRead)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/perf/linux/verity_read_benchmark.cc b/test/perf/linux/verity_read_benchmark.cc new file mode 100644 index 000000000..738b5ba45 --- /dev/null +++ b/test/perf/linux/verity_read_benchmark.cc @@ -0,0 +1,69 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include <fcntl.h> +#include <stdlib.h> +#include <sys/mount.h> +#include <unistd.h> + +#include <memory> +#include <string> +#include <vector> + +#include "gtest/gtest.h" +#include "benchmark/benchmark.h" +#include "test/util/capability_util.h" +#include "test/util/fs_util.h" +#include "test/util/logging.h" +#include "test/util/temp_path.h" +#include "test/util/test_util.h" +#include "test/util/verity_util.h" + +namespace gvisor { +namespace testing { + +namespace { + +void BM_VerityRead(benchmark::State& state) { + const int size = state.range(0); + const std::string contents(size, 0); + + // Mount a tmpfs file system to be wrapped by a verity fs. + TempPath dir = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateDir()); + TEST_CHECK(mount("", dir.path().c_str(), "tmpfs", 0, "") == 0); + + auto path = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileWith( + dir.path(), contents, TempPath::kDefaultFileMode)); + std::string filename = std::string(Basename(path.path())); + + std::string verity_dir = TEST_CHECK_NO_ERRNO_AND_VALUE( + MountVerity(dir.path(), {EnableTarget(filename, O_RDONLY)})); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(Open(JoinPath(verity_dir, filename), O_RDONLY)); + std::vector<char> buf(size); + for (auto _ : state) { + TEST_CHECK(PreadFd(fd.get(), buf.data(), buf.size(), 0) == size); + } + + state.SetBytesProcessed(static_cast<int64_t>(size) * + static_cast<int64_t>(state.iterations())); +} + +BENCHMARK(BM_VerityRead)->Range(1, 1 << 26)->UseRealTime(); + +} // namespace + +} // namespace testing +} // namespace gvisor diff --git a/test/root/chroot_test.go b/test/root/chroot_test.go index 58fcd6f08..5114a9602 100644 --- a/test/root/chroot_test.go +++ b/test/root/chroot_test.go @@ -68,13 +68,15 @@ func TestChroot(t *testing.T) { if err != nil { t.Fatalf("error listing %q: %v", chroot, err) } - if want, got := 1, len(fi); want != got { + if want, got := 2, len(fi); want != got { t.Fatalf("chroot dir got %d entries, want %d", got, want) } - // 
chroot dir is prepared by runsc and should contains only /proc. - if fi[0].Name() != "proc" { - t.Errorf("chroot got children %v, want %v", fi[0].Name(), "proc") + // chroot dir is prepared by runsc and should contain only /etc and /proc. + for i, want := range []string{"etc", "proc"} { + if got := fi[i].Name(); got != want { + t.Errorf("chroot got child %v, want %v", got, want) + } } d.CleanUp(ctx) diff --git a/test/syscalls/linux/BUILD b/test/syscalls/linux/BUILD index 7185df076..7129a797b 100644 --- a/test/syscalls/linux/BUILD +++ b/test/syscalls/linux/BUILD @@ -560,6 +560,7 @@ cc_binary( deps = [ "//test/util:eventfd_util", "//test/util:file_descriptor", + "@com_google_absl//absl/memory", gtest, "//test/util:fs_util", "//test/util:posix_error", @@ -1750,6 +1751,7 @@ cc_binary( "//test/util:mount_util", "@com_google_absl//absl/container:node_hash_set", "@com_google_absl//absl/strings", + "@com_google_absl//absl/strings:str_format", "@com_google_absl//absl/synchronization", "@com_google_absl//absl/time", gtest, @@ -4170,9 +4172,11 @@ cc_binary( srcs = ["msgqueue.cc"], linkstatic = 1, deps = [ + "//test/util:capability_util", "//test/util:temp_path", "//test/util:test_main", "//test/util:test_util", + "@com_google_absl//absl/time", ], ) diff --git a/test/syscalls/linux/dup.cc b/test/syscalls/linux/dup.cc index ba4e13fb9..8f0974f45 100644 --- a/test/syscalls/linux/dup.cc +++ b/test/syscalls/linux/dup.cc @@ -13,9 +13,11 @@ // limitations under the License.
#include <fcntl.h> +#include <sys/resource.h> #include <unistd.h> #include "gtest/gtest.h" +#include "absl/memory/memory.h" #include "test/util/eventfd_util.h" #include "test/util/file_descriptor.h" #include "test/util/fs_util.h" @@ -98,6 +100,45 @@ TEST(DupTest, Dup2) { ASSERT_NO_ERRNO(CheckSameFile(fd, nfd2)); } +TEST(DupTest, Rlimit) { + auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); + FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); + + struct rlimit rl = {}; + EXPECT_THAT(getrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); + + ASSERT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); + + constexpr int kFDLimit = 101; + // Create a file descriptor that will be above the limit. + FileDescriptor aboveLimitFD = + ASSERT_NO_ERRNO_AND_VALUE(Dup2(fd, kFDLimit * 2 - 1)); + + rl.rlim_cur = kFDLimit; + ASSERT_THAT(setrlimit(RLIMIT_NOFILE, &rl), SyscallSucceeds()); + ASSERT_THAT(dup3(fd.get(), kFDLimit, 0), SyscallFails()); + + std::vector<std::unique_ptr<FileDescriptor>> fds; + int prev_fd = fd.get(); + int used_fds = 0; + for (int i = 0; i < kFDLimit; ++i) { + int new_fd = dup(fd.get()); + if (new_fd == -1) { + break; + } + auto f = absl::make_unique<FileDescriptor>(new_fd); + EXPECT_LT(new_fd, kFDLimit); + EXPECT_GT(new_fd, prev_fd); + // Check that all fds in (prev_fd, new_fd) are used. 
+ for (int j = prev_fd + 1; j < new_fd; ++j) { + if (fcntl(j, F_GETFD) != -1) used_fds++; + } + prev_fd = new_fd; + fds.push_back(std::move(f)); + } + EXPECT_EQ(fds.size() + used_fds, kFDLimit - fd.get() - 1); +} + TEST(DupTest, Dup2SameFD) { auto f = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(f.path(), O_RDONLY)); diff --git a/test/syscalls/linux/lseek.cc b/test/syscalls/linux/lseek.cc index d4f89527c..dbc21833f 100644 --- a/test/syscalls/linux/lseek.cc +++ b/test/syscalls/linux/lseek.cc @@ -121,7 +121,8 @@ TEST(LseekTest, InvalidFD) { } TEST(LseekTest, DirCurEnd) { - const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open("/tmp", O_RDONLY)); + const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE( + Open(GetAbsoluteTestTmpdir().c_str(), O_RDONLY)); ASSERT_THAT(lseek(fd.get(), 0, SEEK_CUR), SyscallSucceedsWithValue(0)); } diff --git a/test/syscalls/linux/memfd.cc b/test/syscalls/linux/memfd.cc index 4a450742b..dbd1c93ae 100644 --- a/test/syscalls/linux/memfd.cc +++ b/test/syscalls/linux/memfd.cc @@ -445,9 +445,10 @@ TEST(MemfdTest, SealsAreInodeLevelProperties) { // Tmpfs files also support seals, but are created with F_SEAL_SEAL. 
TEST(MemfdTest, TmpfsFilesHaveSealSeal) { - SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs("/tmp"))); + std::string tmpdir = GetAbsoluteTestTmpdir(); + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(IsTmpfs(tmpdir.c_str()))); const TempPath tmpfs_file = - ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn("/tmp")); + ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFileIn(tmpdir.c_str())); const FileDescriptor fd = ASSERT_NO_ERRNO_AND_VALUE(Open(tmpfs_file.path(), O_RDWR, 0644)); EXPECT_THAT(fcntl(fd.get(), F_GET_SEALS), diff --git a/test/syscalls/linux/msgqueue.cc b/test/syscalls/linux/msgqueue.cc index 2409de7e8..837e913d9 100644 --- a/test/syscalls/linux/msgqueue.cc +++ b/test/syscalls/linux/msgqueue.cc @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +#include <errno.h> #include <sys/ipc.h> #include <sys/msg.h> #include <sys/types.h> +#include "absl/time/clock.h" +#include "test/util/capability_util.h" #include "test/util/temp_path.h" #include "test/util/test_util.h" @@ -23,6 +26,10 @@ namespace gvisor { namespace testing { namespace { +constexpr int msgMax = 8192; // Max size for message in bytes. +constexpr int msgMni = 32000; // Max number of identifiers. +constexpr int msgMnb = 16384; // Default max size of message queue in bytes. + // Queue is a RAII class used to automatically clean message queues. class Queue { public: @@ -46,6 +53,25 @@ class Queue { int id_ = -1; }; +// Default size for messages. +constexpr size_t msgSize = 50; + +// msgbuf is a simple buffer used to send and receive text messages for +// testing purposes. +struct msgbuf { + int64_t mtype; + char mtext[msgSize]; +}; + +bool operator==(msgbuf& a, msgbuf& b) { + for (size_t i = 0; i < msgSize; i++) { + if (a.mtext[i] != b.mtext[i]) { + return false; + } + } + return a.mtype == b.mtype; +} + // Test simple creation and retrieval for msgget(2).
TEST(MsgqueueTest, MsgGet) { const TempPath keyfile = ASSERT_NO_ERRNO_AND_VALUE(TempPath::CreateFile()); @@ -82,6 +108,552 @@ TEST(MsgqueueTest, MsgGetIpcPrivate) { EXPECT_NE(queue1.get(), queue2.get()); } +// Test simple msgsnd and msgrcv. +TEST(MsgqueueTest, MsgOpSimple) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, "A message."}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); +} + +// Test msgsnd and msgrcv of an empty message. +TEST(MsgqueueTest, MsgOpEmpty) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, 0, 0), SyscallSucceeds()); + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0), + SyscallSucceedsWithValue(0)); +} + +// Test truncation of message with MSG_NOERROR flag. +TEST(MsgqueueTest, MsgOpTruncate) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) - 1, 0, MSG_NOERROR), + SyscallSucceedsWithValue(sizeof(buf.mtext) - 1)); +} + +// Test msgsnd and msgrcv using invalid arguments. 
+TEST(MsgqueueTest, MsgOpInvalidArgs) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + + EXPECT_THAT(msgsnd(-1, &buf, 0, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(msgsnd(queue.get(), &buf, -1, 0), SyscallFailsWithErrno(EINVAL)); + + buf.mtype = -1; + EXPECT_THAT(msgsnd(queue.get(), &buf, 1, 0), SyscallFailsWithErrno(EINVAL)); + + EXPECT_THAT(msgrcv(-1, &buf, 1, 0, 0), SyscallFailsWithErrno(EINVAL)); + EXPECT_THAT(msgrcv(queue.get(), &buf, -1, 0, 0), + SyscallFailsWithErrno(EINVAL)); +} + +// Test non-blocking msgrcv with an empty queue. +TEST(MsgqueueTest, MsgOpNoMsg) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(rcv.mtext) + 1, 0, IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); +} + +// Test non-blocking msgrcv with a non-empty queue, but no messages of wanted +// type. +TEST(MsgqueueTest, MsgOpNoMsgType) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext) + 1, 2, IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); +} + +// Test msgrcv with a larger size message than wanted, and truncation disabled. +TEST(MsgqueueTest, MsgOpTooBig) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext) - 1, 0, 0), + SyscallFailsWithErrno(E2BIG)); +} + +// Test receiving messages based on type. 
+TEST(MsgqueueTest, MsgRcvType) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Send messages in an order and receive them in reverse, based on type, + // which shouldn't block. + std::map<int64_t, msgbuf> typeToBuf = { + {1, msgbuf{1, "Message 1."}}, {2, msgbuf{2, "Message 2."}}, + {3, msgbuf{3, "Message 3."}}, {4, msgbuf{4, "Message 4."}}, + {5, msgbuf{5, "Message 5."}}, {6, msgbuf{6, "Message 6."}}, + {7, msgbuf{7, "Message 7."}}, {8, msgbuf{8, "Message 8."}}, + {9, msgbuf{9, "Message 9."}}}; + + for (auto const& [type, buf] : typeToBuf) { + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + for (int64_t i = typeToBuf.size(); i > 0; i--) { + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(typeToBuf[i].mtext) + 1, i, 0), + SyscallSucceedsWithValue(sizeof(typeToBuf[i].mtext))); + EXPECT_TRUE(typeToBuf[i] == rcv); + } +} + +// Test using MSG_EXCEPT to receive a different-type message. +TEST(MsgqueueTest, MsgExcept) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + std::map<int64_t, msgbuf> typeToBuf = { + {1, msgbuf{1, "Message 1."}}, + {2, msgbuf{2, "Message 2."}}, + }; + + for (auto const& [type, buf] : typeToBuf) { + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + for (int64_t i = typeToBuf.size(); i > 0; i--) { + msgbuf actual = typeToBuf[i == 1 ? 2 : 1]; + msgbuf rcv; + + EXPECT_THAT( + msgrcv(queue.get(), &rcv, sizeof(actual.mtext) + 1, i, MSG_EXCEPT), + SyscallSucceedsWithValue(sizeof(actual.mtext))); + EXPECT_TRUE(actual == rcv); + } +} + +// Test msgrcv with a negative type. +TEST(MsgqueueTest, MsgRcvTypeNegative) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // When msgtyp is negative, msgrcv returns the first message with mtype less + // than or equal to the absolute value. 
+ msgbuf buf{2, "A message."}; + msgbuf rcv; + + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + // Nothing is less than or equal to 1. + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, -1, IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, -3, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); +} + +// Test permission-related failure scenarios. +TEST(MsgqueueTest, MsgOpPermissions) { + AutoCapability cap(CAP_IPC_OWNER, false); + + Queue queue(msgget(IPC_PRIVATE, 0000)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, ""}; + + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallFailsWithErrno(EACCES)); + EXPECT_THAT(msgrcv(queue.get(), &buf, sizeof(buf.mtext), 0, 0), + SyscallFailsWithErrno(EACCES)); +} + +// Test limits for messages and queues. +TEST(MsgqueueTest, MsgOpLimits) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, "A message."}; + + // Limit for one message. + EXPECT_THAT(msgsnd(queue.get(), &buf, msgMax + 1, 0), + SyscallFailsWithErrno(EINVAL)); + + // Limit for queue. + // Use a buffer with the maximum amount of bytes that can be transferred to + // make it easier to exhaust the queue limit. + struct msgmax { + int64_t mtype; + char mtext[msgMax]; + }; + + msgmax limit{1, ""}; + for (size_t i = 0, msgCount = msgMnb / msgMax; i < msgCount; i++) { + EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), 0), + SyscallSucceeds()); + } + EXPECT_THAT(msgsnd(queue.get(), &limit, sizeof(limit.mtext), IPC_NOWAIT), + SyscallFailsWithErrno(EAGAIN)); +} + +// MsgCopySupported returns true if MSG_COPY is supported. +bool MsgCopySupported() { + // msgrcv(2) man page states that MSG_COPY flag is available only if the + // kernel was built with the CONFIG_CHECKPOINT_RESTORE option.
If MSG_COPY + // is used when the kernel was configured without the option, msgrcv produces + // a ENOSYS error. + // To avoid test failure, we perform a small test using msgrcv, and skip the + // test if errno == ENOSYS. This means that the test will always run on + // gVisor, but may be skipped on native linux. + + Queue queue(msgget(IPC_PRIVATE, 0600)); + + msgbuf buf{1, "Test message."}; + msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0); + + return !(msgrcv(queue.get(), &buf, sizeof(buf.mtext) + 1, 0, + MSG_COPY | IPC_NOWAIT) == -1 && + errno == ENOSYS); +} + +// Test msgrcv using MSG_COPY. +TEST(MsgqueueTest, MsgCopy) { + SKIP_IF(!MsgCopySupported()); + + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf bufs[5] = { + msgbuf{1, "Message 1."}, msgbuf{2, "Message 2."}, msgbuf{3, "Message 3."}, + msgbuf{4, "Message 4."}, msgbuf{5, "Message 5."}, + }; + + for (auto& buf : bufs) { + ASSERT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + // Receive a copy of the messages. + for (size_t i = 0, size = sizeof(bufs) / sizeof(bufs[0]); i < size; i++) { + msgbuf buf = bufs[i]; + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, i, + MSG_COPY | IPC_NOWAIT), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); + } + + // Re-receive the messages normally. + for (auto& buf : bufs) { + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + EXPECT_TRUE(buf == rcv); + } +} + +// Test msgrcv using MSG_COPY with invalid arguments. 
+TEST(MsgqueueTest, MsgCopyInvalidArgs) { + SKIP_IF(!MsgCopySupported()); + + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, 1, MSG_COPY), + SyscallFailsWithErrno(EINVAL)); + + EXPECT_THAT( + msgrcv(queue.get(), &rcv, msgSize, 5, MSG_COPY | MSG_EXCEPT | IPC_NOWAIT), + SyscallFailsWithErrno(EINVAL)); +} + +// Test msgrcv using MSG_COPY with invalid indices. +TEST(MsgqueueTest, MsgCopyInvalidIndex) { + SKIP_IF(!MsgCopySupported()); + + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf rcv; + EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, -3, MSG_COPY | IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, msgSize, 5, MSG_COPY | IPC_NOWAIT), + SyscallFailsWithErrno(ENOMSG)); +} + +// Test msgrcv (most probably) blocking on an empty queue. +TEST(MsgqueueTest, MsgRcvBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf buf{1, "A message."}; + + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, 0, + 0) == sizeof(buf.mtext) && + buf == rcv); + _exit(0); + } + + // Sleep to try and make msgrcv block before sending a message. + absl::SleepFor(absl::Milliseconds(150)); + + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test msgrcv (most probably) waiting for a specific-type message. 
+TEST(MsgqueueTest, MsgRcvTypeBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + msgbuf bufs[5] = {{1, "A message."}, + {1, "A message."}, + {1, "A message."}, + {1, "A message."}, + {2, "A different message."}}; + + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf buf = bufs[4]; // Buffer that should be received. + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, 2, + 0) == sizeof(buf.mtext) && + buf == rcv); + _exit(0); + } + + // Sleep to try and make msgrcv block before sending messages. + absl::SleepFor(absl::Milliseconds(150)); + + // Send all buffers in order, only last one should be received. + for (auto& buf : bufs) { + EXPECT_THAT(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0), + SyscallSucceeds()); + } + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test msgsnd (most probably) blocking on a full queue. +TEST(MsgqueueTest, MsgSndBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Use a buffer with the maximum amount of bytes that can be transferred to + // make it easier to exhaust the queue limit. + struct msgmax { + int64_t mtype; + char mtext[msgMax]; + }; + + msgmax buf{1, ""}; // Has max amount of bytes. + + const size_t msgCount = msgMnb / msgMax; // Number of messages that can be + // sent without blocking. + + const pid_t child_pid = fork(); + if (child_pid == 0) { + // Fill the queue. + for (size_t i = 0; i < msgCount; i++) { + TEST_PCHECK(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0) == 0); + } + + // Next msgsnd should block.
+ TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) == + 0); + _exit(0); + } + + // To increase the chance of the last msgsnd blocking before doing a msgrcv, + // we use MSG_COPY option to copy the last index in the queue. As long as + // MSG_COPY fails, the queue hasn't yet been filled. When MSG_COPY succeeds, + // the queue is filled, and most probably, a blocking msgsnd has been made. + msgmax rcv; + while (msgrcv(queue.get(), &rcv, msgMax, msgCount - 1, + MSG_COPY | IPC_NOWAIT) == -1 && + errno == ENOMSG) { + } + + // Delay a bit more for the blocking msgsnd. + absl::SleepFor(absl::Milliseconds(100)); + + EXPECT_THAT(msgrcv(queue.get(), &rcv, sizeof(buf.mtext), 0, 0), + SyscallSucceedsWithValue(sizeof(buf.mtext))); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test removing a queue while a blocking msgsnd is executing. +TEST(MsgqueueTest, MsgSndRmWhileBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Use a buffer with the maximum amount of bytes that can be transferred to + // make it easier to exhaust the queue limit. + struct msgmax { + int64_t mtype; + char mtext[msgMax]; + }; + + const size_t msgCount = msgMnb / msgMax; // Number of messages that can be + // sent without blocking. + const pid_t child_pid = fork(); + if (child_pid == 0) { + // Fill the queue. + msgmax buf{1, ""}; + for (size_t i = 0; i < msgCount; i++) { + // The parent only checks the child's exit status, so check directly. + TEST_PCHECK(msgsnd(queue.get(), &buf, sizeof(buf.mtext), 0) == 0); + } + + // Next msgsnd should block. Because we're repeating on EINTR, msgsnd may + // race with msgctl(IPC_RMID) and return EINVAL.
+ TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) == + -1 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + // Similar to MsgSndBlocking, we do this to increase the chance of msgsnd + // blocking before removing the queue. + msgmax rcv; + while (msgrcv(queue.get(), &rcv, msgMax, msgCount - 1, + MSG_COPY | IPC_NOWAIT) == -1 && + errno == ENOMSG) { + } + absl::SleepFor(absl::Milliseconds(100)); + + EXPECT_THAT(msgctl(queue.release(), IPC_RMID, nullptr), SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test removing a queue while a blocking msgrcv is executing. +TEST(MsgqueueTest, MsgRcvRmWhileBlocking) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + const pid_t child_pid = fork(); + if (child_pid == 0) { + // Because we're repeating on EINTR, msgrcv may race with msgctl(IPC_RMID) + // and return EINVAL. + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, 1, 2, 0) == -1 && + (errno == EIDRM || errno == EINVAL)); + _exit(0); + } + + // Sleep to try and make msgrcv block before removing the queue. + absl::SleepFor(absl::Milliseconds(150)); + + EXPECT_THAT(msgctl(queue.release(), IPC_RMID, nullptr), SyscallSucceeds()); + + int status; + ASSERT_THAT(RetryEINTR(waitpid)(child_pid, &status, 0), + SyscallSucceedsWithValue(child_pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); +} + +// Test a collection of msgsnd/msgrcv operations in different processes. +TEST(MsgqueueTest, MsgOpGeneral) { + Queue queue(msgget(IPC_PRIVATE, 0600)); + ASSERT_THAT(queue.get(), SyscallSucceeds()); + + // Create 50 sending, and 50 receiving processes. There are only 5 messages to + // be sent and received, each with a different type. All messages will be sent + // and received equally (10 of each.)
By the end of the test all processes + // should unblock and return normally. + const size_t msgCount = 5; + std::map<int64_t, msgbuf> typeToBuf = {{1, msgbuf{1, "Message 1."}}, + {2, msgbuf{2, "Message 2."}}, + {3, msgbuf{3, "Message 3."}}, + {4, msgbuf{4, "Message 4."}}, + {5, msgbuf{5, "Message 5."}}}; + + std::vector<pid_t> children; + + const size_t pCount = 50; + for (size_t i = 1; i <= pCount; i++) { + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf buf = typeToBuf[(i % msgCount) + 1]; + msgbuf rcv; + TEST_PCHECK(RetryEINTR(msgrcv)(queue.get(), &rcv, sizeof(buf.mtext) + 1, + (i % msgCount) + 1, + 0) == sizeof(buf.mtext) && + buf == rcv); + _exit(0); + } + children.push_back(child_pid); + } + + for (size_t i = 1; i <= pCount; i++) { + const pid_t child_pid = fork(); + if (child_pid == 0) { + msgbuf buf = typeToBuf[(i % msgCount) + 1]; + TEST_PCHECK(RetryEINTR(msgsnd)(queue.get(), &buf, sizeof(buf.mtext), 0) == + 0); + _exit(0); + } + children.push_back(child_pid); + } + + for (auto const& pid : children) { + int status; + ASSERT_THAT(RetryEINTR(waitpid)(pid, &status, 0), + SyscallSucceedsWithValue(pid)); + EXPECT_TRUE(WIFEXITED(status) && WEXITSTATUS(status) == 0); + } +} + } // namespace } // namespace testing } // namespace gvisor diff --git a/test/syscalls/linux/prctl.cc b/test/syscalls/linux/prctl.cc index 25b0e63d4..286b3d168 100644 --- a/test/syscalls/linux/prctl.cc +++ b/test/syscalls/linux/prctl.cc @@ -214,6 +214,12 @@ TEST(PrctlTest, RootDumpability) { SyscallFailsWithErrno(EINVAL)); } +TEST(PrctlTest, SetGetSubreaper) { + // Setting subreaper on PID 1 works vacuously because PID 1 is always a + // subreaper. 
+ EXPECT_THAT(prctl(PR_SET_CHILD_SUBREAPER, 1), SyscallSucceeds()); +} + } // namespace } // namespace testing diff --git a/test/syscalls/linux/proc.cc b/test/syscalls/linux/proc.cc index 78aa73edc..8a4025fed 100644 --- a/test/syscalls/linux/proc.cc +++ b/test/syscalls/linux/proc.cc @@ -54,6 +54,8 @@ #include "absl/strings/match.h" #include "absl/strings/numbers.h" #include "absl/strings/str_cat.h" +#include "absl/strings/str_format.h" +#include "absl/strings/str_join.h" #include "absl/strings/str_split.h" #include "absl/strings/string_view.h" #include "absl/synchronization/mutex.h" @@ -88,6 +90,7 @@ using ::testing::Gt; using ::testing::HasSubstr; using ::testing::IsSupersetOf; using ::testing::Pair; +using ::testing::StartsWith; using ::testing::UnorderedElementsAre; using ::testing::UnorderedElementsAreArray; @@ -1622,10 +1625,41 @@ TEST(ProcPidStatusTest, HasBasicFields) { ASSERT_FALSE(status_str.empty()); const auto status = ASSERT_NO_ERRNO_AND_VALUE(ParseProcStatus(status_str)); - EXPECT_THAT(status, IsSupersetOf({Pair("Name", thread_name), - Pair("Tgid", absl::StrCat(tgid)), - Pair("Pid", absl::StrCat(tid)), - Pair("PPid", absl::StrCat(getppid()))})); + EXPECT_THAT(status, IsSupersetOf({ + Pair("Name", thread_name), + Pair("Tgid", absl::StrCat(tgid)), + Pair("Pid", absl::StrCat(tid)), + Pair("PPid", absl::StrCat(getppid())), + })); + + if (!IsRunningWithVFS1()) { + uid_t ruid, euid, suid; + ASSERT_THAT(getresuid(&ruid, &euid, &suid), SyscallSucceeds()); + gid_t rgid, egid, sgid; + ASSERT_THAT(getresgid(&rgid, &egid, &sgid), SyscallSucceeds()); + std::vector<gid_t> supplementary_gids; + int ngids = getgroups(0, nullptr); + supplementary_gids.resize(ngids); + ASSERT_THAT(getgroups(ngids, supplementary_gids.data()), + SyscallSucceeds()); + + EXPECT_THAT( + status, + IsSupersetOf(std::vector< + ::testing::Matcher<std::pair<std::string, std::string>>>{ + // gVisor doesn't support fsuid/gid, and even if it did there is + // no getfsuid/getfsgid(). 
+ Pair("Uid", StartsWith(absl::StrFormat("%d\t%d\t%d\t", ruid, euid, + suid))), + Pair("Gid", StartsWith(absl::StrFormat("%d\t%d\t%d\t", rgid, egid, + sgid))), + // ParseProcStatus strips leading whitespace for each value, + // so if the Groups line is empty then the trailing space is + // stripped. + Pair("Groups", + StartsWith(absl::StrJoin(supplementary_gids, " "))), + })); + } }); } diff --git a/test/syscalls/linux/socket_netlink_route.cc b/test/syscalls/linux/socket_netlink_route.cc index a5c788346..d5e1ce0cc 100644 --- a/test/syscalls/linux/socket_netlink_route.cc +++ b/test/syscalls/linux/socket_netlink_route.cc @@ -44,6 +44,7 @@ namespace { constexpr uint32_t kSeq = 12345; +using ::testing::_; using ::testing::AnyOf; using ::testing::Eq; @@ -244,7 +245,7 @@ TEST(NetlinkRouteTest, GetLinkByIndexNotFound) { req.ifm.ifi_index = 1234590; EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), - PosixErrorIs(ENODEV, ::testing::_)); + PosixErrorIs(ENODEV, _)); } TEST(NetlinkRouteTest, GetLinkByNameNotFound) { @@ -273,7 +274,112 @@ TEST(NetlinkRouteTest, GetLinkByNameNotFound) { NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), - PosixErrorIs(ENODEV, ::testing::_)); + PosixErrorIs(ENODEV, _)); +} + +TEST(NetlinkRouteTest, RemoveLoopbackByName) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_DELLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = 
RTA_LENGTH(loopback_link.name.size() + 1); + strncpy(req.ifname, loopback_link.name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENOTSUP, _)); +} + +TEST(NetlinkRouteTest, RemoveLoopbackByIndex) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + Link loopback_link = ASSERT_NO_ERRNO_AND_VALUE(LoopbackLink()); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_DELLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = loopback_link.index; + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENOTSUP, _)); +} + +TEST(NetlinkRouteTest, RemoveLinkByIndexNotFound) { + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + }; + + struct request req = {}; + req.hdr.nlmsg_len = sizeof(req); + req.hdr.nlmsg_type = RTM_DELLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.ifm.ifi_index = 1234590; + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, _)); +} + +TEST(NetlinkRouteTest, RemoveLinkByNameNotFound) { + const std::string name = "nodevice?!"; + + SKIP_IF(!ASSERT_NO_ERRNO_AND_VALUE(HaveCapability(CAP_NET_ADMIN))); + FileDescriptor fd = + ASSERT_NO_ERRNO_AND_VALUE(NetlinkBoundSocket(NETLINK_ROUTE)); + + struct request { + struct nlmsghdr hdr; + struct ifinfomsg ifm; + struct rtattr rtattr; + char ifname[IFNAMSIZ]; + char
pad[NLMSG_ALIGNTO + RTA_ALIGNTO]; + }; + + struct request req = {}; + req.hdr.nlmsg_type = RTM_DELLINK; + req.hdr.nlmsg_flags = NLM_F_REQUEST; + req.hdr.nlmsg_seq = kSeq; + req.ifm.ifi_family = AF_UNSPEC; + req.rtattr.rta_type = IFLA_IFNAME; + req.rtattr.rta_len = RTA_LENGTH(name.size() + 1); + strncpy(req.ifname, name.c_str(), sizeof(req.ifname)); + req.hdr.nlmsg_len = + NLMSG_LENGTH(sizeof(req.ifm)) + NLMSG_ALIGN(req.rtattr.rta_len); + + EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), + PosixErrorIs(ENODEV, _)); } TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { @@ -295,7 +401,7 @@ TEST(NetlinkRouteTest, MsgHdrMsgUnsuppType) { req.ifm.ifi_family = AF_UNSPEC; EXPECT_THAT(NetlinkRequestAckOrError(fd, kSeq, &req, sizeof(req)), - PosixErrorIs(EOPNOTSUPP, ::testing::_)); + PosixErrorIs(EOPNOTSUPP, _)); } TEST(NetlinkRouteTest, MsgHdrMsgTrunc) { @@ -536,7 +642,7 @@ TEST(NetlinkRouteTest, AddAndRemoveAddr) { // Second delete should fail, as address no longer exists. EXPECT_THAT(LinkDelLocalAddr(loopback_link.index, AF_INET, /*prefixlen=*/24, &addr, sizeof(addr)), - PosixErrorIs(EADDRNOTAVAIL, ::testing::_)); + PosixErrorIs(EADDRNOTAVAIL, _)); }); // Replace an existing address should succeed. @@ -546,7 +652,7 @@ TEST(NetlinkRouteTest, AddAndRemoveAddr) { // Create exclusive should fail, as we created the address above. EXPECT_THAT(LinkAddExclusiveLocalAddr(loopback_link.index, AF_INET, /*prefixlen=*/24, &addr, sizeof(addr)), - PosixErrorIs(EEXIST, ::testing::_)); + PosixErrorIs(EEXIST, _)); } // GetRouteDump tests a RTM_GETROUTE + NLM_F_DUMP request. 
diff --git a/test/syscalls/linux/stat.cc b/test/syscalls/linux/stat.cc index 72f888659..19dc80d0c 100644 --- a/test/syscalls/linux/stat.cc +++ b/test/syscalls/linux/stat.cc @@ -765,7 +765,7 @@ TEST_F(StatTest, StatxSymlink) { SKIP_IF(!IsRunningOnGvisor() && statx(-1, nullptr, 0, 0, nullptr) < 0 && errno == ENOSYS); - std::string parent_dir = "/tmp"; + std::string parent_dir = GetAbsoluteTestTmpdir(); TempPath link = ASSERT_NO_ERRNO_AND_VALUE( TempPath::CreateSymlinkTo(parent_dir, test_file_name_)); std::string p = link.path(); diff --git a/test/syscalls/linux/statfs.cc b/test/syscalls/linux/statfs.cc index d4ea8e026..d057cdc09 100644 --- a/test/syscalls/linux/statfs.cc +++ b/test/syscalls/linux/statfs.cc @@ -28,7 +28,7 @@ namespace testing { namespace { TEST(StatfsTest, CannotStatBadPath) { - auto temp_file = NewTempAbsPathInDir("/tmp"); + auto temp_file = NewTempAbsPath(); struct statfs st; EXPECT_THAT(statfs(temp_file.c_str(), &st), SyscallFailsWithErrno(ENOENT)); diff --git a/tools/checklinkname/BUILD b/tools/checklinkname/BUILD new file mode 100644 index 000000000..0f1b07e24 --- /dev/null +++ b/tools/checklinkname/BUILD @@ -0,0 +1,16 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "checklinkname", + srcs = [ + "check_linkname.go", + "known.go", + ], + nogo = False, + visibility = ["//tools/nogo:__subpackages__"], + deps = [ + "@org_golang_x_tools//go/analysis:go_default_library", + ], +) diff --git a/tools/checklinkname/README.md b/tools/checklinkname/README.md new file mode 100644 index 000000000..06b3c302d --- /dev/null +++ b/tools/checklinkname/README.md @@ -0,0 +1,54 @@ +# `checklinkname` Analyzer + +`checklinkname` is an analyzer to provide rudimentary type-checking for +`//go:linkname` directives. Since `//go:linkname` only affects linker behavior, +there is no built-in type safety and it is the programmer's responsibility to +ensure the types on either side are compatible. 
+ +`checklinkname` helps with this by checking that uses match expectations, as +defined in this package. + +`known.go` contains the set of known linkname targets. For most functions, we +expect identical types on both sides of the linkname. In a few cases, the types +may be slightly different (e.g., local redefinition of internal type). It is +still the responsibility of the programmer to ensure the signatures in +`known.go` are compatible and safe. + +## Findings + +Here are the most common findings from this package, and how to resolve them. + +### `runtime.foo signature got "BAR" want "BAZ"; stdlib type changed?` + +The definition of `runtime.foo` in the standard library does not match the +expected type in `known.go`. This means that the function signature in the +standard library changed. + +Addressing this will require creating a new linkname directive in a new Go +version build-tagged in any packages using this symbol. Be sure to also check to +ensure use with the new version is safe, as function constraints may have +changed in addition to the signature. + +<!-- TODO(b/165820485): This isn't yet explicitly supported. --> + +`known.go` will also need to be updated to accept the new signature for the new +version of Go. + +### `Cannot find known symbol "runtime.foo"` + +The standard library has removed runtime.foo entirely. Handling is similar to +above, except existing code must transition away from the symbol entirely (note +that it may simply be renamed). + +### `linkname to unknown symbol "mypkg.foo"; add this symbol to checklinkname.knownLinknames type-check against the remote type` + +A package has added a new linkname directive for a symbol not listed in +`known.go`. Address this by adding a new entry for the target symbol. The +`local` field should be the expected type in your package, while `remote` should +be the expected type in the remote package (e.g., in the standard library). These +are typically identical, in which case `remote` can be omitted.
+ +### `usage: //go:linkname localname [linkname]` + +Malformed `//go:linkname` directive. This should be accompanied by a build +failure in the package. diff --git a/tools/checklinkname/check_linkname.go b/tools/checklinkname/check_linkname.go new file mode 100644 index 000000000..5373dd762 --- /dev/null +++ b/tools/checklinkname/check_linkname.go @@ -0,0 +1,229 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package checklinkname ensures that linkname declarations match their source. +package checklinkname + +import ( + "fmt" + "go/ast" + "go/token" + "go/types" + "strings" + + "golang.org/x/tools/go/analysis" +) + +// Analyzer implements the checklinkname analyzer. +var Analyzer = &analysis.Analyzer{ + Name: "checklinkname", + Doc: "verifies that linkname declarations match their source", + Run: run, +} + +// go:linkname can be rather confusing. https://pkg.go.dev/cmd/compile says: +// +// //go:linkname localname [importpath.name] +// +// This special directive does not apply to the Go code that follows it. +// Instead, the //go:linkname directive instructs the compiler to use +// “importpath.name” as the object file symbol name for the variable or +// function declared as “localname” in the source code. If the +// “importpath.name” argument is omitted, the directive uses the symbol's +// default object file symbol name and only has the effect of making the symbol +// accessible to other packages. 
Because this directive can subvert the type +// system and package modularity, it is only enabled in files that have +// imported "unsafe". +// +// In this package we use the term "local" to refer to the symbol name in the +// same package as the //go:linkname directive, whose name will be changed by +// the linker. We use the term "remote" to refer to the symbol name that we are +// changing to. +// +// In the general case, the local symbol is a function declaration, and the +// remote symbol is a real function in the standard library. + +// linknameSignatures describes the type signatures of the symbols in a +// //go:linkname directive. +type linknameSignatures struct { + local string + remote string // equivalent to local if "". +} + +func (l *linknameSignatures) Remote() string { + if l.remote == "" { + return l.local + } + return l.remote +} + +// linknameSymbols describes the symbol names in a single //go:linkname +// directive. +type linknameSymbols struct { + pos token.Pos + local string + remote string +} + +func findLinknames(pass *analysis.Pass, f *ast.File) []linknameSymbols { + var names []linknameSymbols + + for _, cg := range f.Comments { + for _, c := range cg.List { + if len(c.Text) <= 2 || !strings.HasPrefix(c.Text[2:], "go:linkname ") { + continue + } + + f := strings.Fields(c.Text) + if len(f) < 2 || len(f) > 3 { + // Malformed linkname. This is the same error the compiler emits. + pass.Reportf(c.Slash, "usage: //go:linkname localname [linkname]") + } + + if len(f) == 2 { + // "If the “importpath.name” argument is + // omitted, the directive uses the symbol's + // default object file symbol name and only has + // the effect of making the symbol accessible + // to other packages." + // -https://golang.org/cmd/compile + // + // There is no type-checking to be done here.
+ continue + } + + names = append(names, linknameSymbols{ + pos: c.Slash, + local: f[1], + remote: f[2], + }) + } + } + + return names +} + +func splitSymbol(pkg *types.Package, symbol string) (packagePath, name string) { + // Note that some runtime symbols can have multiple dots. e.g., + // runtime..init_task. + s := strings.SplitN(symbol, ".", 2) + + switch len(s) { + case 1: + // Package name omitted, use current package. + return pkg.Path(), symbol + case 2: + return s[0], s[1] + default: + panic("unreachable") + } +} + +func findObject(pkg *types.Package, symbol string) (types.Object, error) { + packagePath, symbolName := splitSymbol(pkg, symbol) + return findPackageObject(pkg, packagePath, symbolName) +} + +func findPackageObject(pkg *types.Package, packagePath, symbolName string) (types.Object, error) { + if pkg.Path() == packagePath { + o := pkg.Scope().Lookup(symbolName) + if o == nil { + return nil, fmt.Errorf("%q not found in %q (names: %+v)", symbolName, packagePath, pkg.Scope().Names()) + } + return o, nil + } + + for _, p := range pkg.Imports() { + if o, err := findPackageObject(p, packagePath, symbolName); err == nil { + return o, nil + } + } + + return nil, fmt.Errorf("package %q not found", packagePath) +} + +// checkOneLinkname verifies that the type of sym.local matches the type from +// knownLinknames. 
+func checkOneLinkname(pass *analysis.Pass, f *ast.File, sym linknameSymbols) { + remotePackage, remoteName := splitSymbol(pass.Pkg, sym.remote) + + m, ok := knownLinknames[remotePackage] + if !ok { + pass.Reportf(sym.pos, "linkname to unknown symbol %q; add this symbol to checklinkname.knownLinknames type-check against the remote type", sym.remote) + return + } + + linkname, ok := m[remoteName] + if !ok { + pass.Reportf(sym.pos, "linkname to unknown symbol %q; add this symbol to checklinkname.knownLinknames type-check against the remote type", sym.remote) + return + } + + local, err := findObject(pass.Pkg, sym.local) + if err != nil { + pass.Reportf(sym.pos, "Unable to find symbol %q: %v", sym.local, err) + return + } + + localSig, ok := local.Type().(*types.Signature) + if !ok { + pass.Reportf(local.Pos(), "%q object is not a signature: %+#v", sym.local, local) + return + } + + if linkname.local != localSig.String() { + pass.Reportf(local.Pos(), "%q signature got %q want %q; mismatched types?", sym.local, localSig.String(), linkname.local) + return + } +} + +// checkOneRemote verifies that the type of sym matches wantSig. +func checkOneRemote(pass *analysis.Pass, sym, wantSig string) { + o := pass.Pkg.Scope().Lookup(sym) + if o == nil { + pass.Reportf(pass.Files[0].Package, "Cannot find known symbol %q", sym) + return + } + + sig, ok := o.Type().(*types.Signature) + if !ok { + pass.Reportf(o.Pos(), "%q object is not a signature: %+#v", sym, o) + return + } + + if sig.String() != wantSig { + pass.Reportf(o.Pos(), "%q signature got %q want %q; stdlib type changed?", sym, sig.String(), wantSig) + return + } +} + +func run(pass *analysis.Pass) (interface{}, error) { + // First, check if any remote symbols are in this package. + p, ok := knownLinknames[pass.Pkg.Path()] + if ok { + for sym, l := range p { + checkOneRemote(pass, sym, l.Remote()) + } + } + + // Then check for local //go:linkname directives in this package. 
+ for _, f := range pass.Files { + names := findLinknames(pass, f) + for _, n := range names { + checkOneLinkname(pass, f, n) + } + } + + return nil, nil +} diff --git a/tools/checklinkname/known.go b/tools/checklinkname/known.go new file mode 100644 index 000000000..54e5155fc --- /dev/null +++ b/tools/checklinkname/known.go @@ -0,0 +1,110 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package checklinkname + +// knownLinknames is the set of the symbols for which we can do a rudimentary +// type-check on. +// +// When analyzing the remote package (e.g., runtime), we verify the symbol +// signature matches 'remote'. When analyzing local packages with //go:linkname +// directives, we verify the symbol signature matches 'local'. +// +// Usually these are identical, but may differ slightly if equivalent +// replacement types are used in the local packages, such as a copy of a struct +// or uintptr instead of a pointer type. +// +// NOTE: It is the responsibility of the developer to verify the safety of the +// signatures used here! This analyzer only checks that types match this map; +// it does not verify compatibility of the entries themselves. +// +// //go:linkname directives with no corresponding entry here will trigger a +// finding. +// +// We perform only rudimentary string-based type-checking due to limitations in +// the analysis framework.
Ideally, from the local package we'd lookup the +// remote symbol's types.Object and perform robust type-checking. +// Unfortunately, remote symbols are typically loaded from the remote package's +// gcexportdata. Since //go:linkname targets are usually not exported symbols, +// they are not included in gcexportdata and we cannot load their types.Object. +// +// TODO(b/165820485): Add option to specify per-version signatures. +var knownLinknames = map[string]map[string]linknameSignatures{ + "runtime": map[string]linknameSignatures{ + "entersyscall": linknameSignatures{ + local: "func()", + }, + "entersyscallblock": linknameSignatures{ + local: "func()", + }, + "exitsyscall": linknameSignatures{ + local: "func()", + }, + "fastrand": linknameSignatures{ + local: "func() uint32", + }, + "gopark": linknameSignatures{ + // TODO(b/165820485): add verification of waitReason + // size and reason and traceEv values. + local: "func(unlockf func(uintptr, unsafe.Pointer) bool, lock unsafe.Pointer, reason uint8, traceEv byte, traceskip int)", + remote: "func(unlockf func(*runtime.g, unsafe.Pointer) bool, lock unsafe.Pointer, reason runtime.waitReason, traceEv byte, traceskip int)", + }, + "goready": linknameSignatures{ + local: "func(gp uintptr, traceskip int)", + remote: "func(gp *runtime.g, traceskip int)", + }, + "goyield": linknameSignatures{ + local: "func()", + }, + "memmove": linknameSignatures{ + local: "func(to unsafe.Pointer, from unsafe.Pointer, n uintptr)", + }, + "throw": linknameSignatures{ + local: "func(s string)", + }, + }, + "sync": map[string]linknameSignatures{ + "runtime_canSpin": linknameSignatures{ + local: "func(i int) bool", + }, + "runtime_doSpin": linknameSignatures{ + local: "func()", + }, + "runtime_Semacquire": linknameSignatures{ + // The only difference here is the parameter names.
We + // can't just change our local use to match remote, as + // the stdlib runtime and sync packages also disagree + // on the name, and the analyzer checks that use as + // well. + local: "func(addr *uint32)", + remote: "func(s *uint32)", + }, + "runtime_Semrelease": linknameSignatures{ + // See above. + local: "func(addr *uint32, handoff bool, skipframes int)", + remote: "func(s *uint32, handoff bool, skipframes int)", + }, + }, + "syscall": map[string]linknameSignatures{ + "runtime_BeforeFork": linknameSignatures{ + local: "func()", + }, + "runtime_AfterFork": linknameSignatures{ + local: "func()", + }, + "runtime_AfterForkInChild": linknameSignatures{ + local: "func()", + }, + }, +} diff --git a/tools/checklinkname/test/BUILD b/tools/checklinkname/test/BUILD new file mode 100644 index 000000000..b29bd84f2 --- /dev/null +++ b/tools/checklinkname/test/BUILD @@ -0,0 +1,9 @@ +load("//tools:defs.bzl", "go_library") + +package(licenses = ["notice"]) + +go_library( + name = "test", + testonly = 1, + srcs = ["test_unsafe.go"], +) diff --git a/tools/checklinkname/test/test_unsafe.go b/tools/checklinkname/test/test_unsafe.go new file mode 100644 index 000000000..a7504591c --- /dev/null +++ b/tools/checklinkname/test/test_unsafe.go @@ -0,0 +1,34 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package test provides linkname test targets. +package test + +import ( + _ "unsafe" // for go:linkname. 
+) + +//go:linkname DetachedLinkname runtime.fastrand + +//go:linkname attachedLinkname runtime.entersyscall +func attachedLinkname() + +// AttachedLinkname reexports attachedLinkname because go vet doesn't like an +// exported go:linkname without a comment starting with "// AttachedLinkname". +func AttachedLinkname() { + attachedLinkname() +} + +// DetachedLinkname has a linkname elsewhere in the file. +func DetachedLinkname() uint32 diff --git a/tools/constraintutil/BUILD b/tools/constraintutil/BUILD new file mode 100644 index 000000000..004b708c4 --- /dev/null +++ b/tools/constraintutil/BUILD @@ -0,0 +1,18 @@ +load("//tools:defs.bzl", "go_library", "go_test") + +package(licenses = ["notice"]) + +go_library( + name = "constraintutil", + srcs = ["constraintutil.go"], + marshal = False, + stateify = False, + visibility = ["//tools:__subpackages__"], +) + +go_test( + name = "constraintutil_test", + size = "small", + srcs = ["constraintutil_test.go"], + library = ":constraintutil", +) diff --git a/tools/constraintutil/constraintutil.go b/tools/constraintutil/constraintutil.go new file mode 100644 index 000000000..fb3fbe5c2 --- /dev/null +++ b/tools/constraintutil/constraintutil.go @@ -0,0 +1,169 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Package constraintutil provides utilities for working with Go build +// constraints. 
+package constraintutil + +import ( + "bufio" + "bytes" + "fmt" + "go/build/constraint" + "io" + "os" + "strings" +) + +// FromReader extracts the build constraint from the Go source or assembly file +// whose contents are read by r. +func FromReader(r io.Reader) (constraint.Expr, error) { + // See go/build.parseFileHeader() for the "official" logic that this is + // derived from. + const ( + slashStar = "/*" + starSlash = "*/" + gobuildPrefix = "//go:build" + ) + s := bufio.NewScanner(r) + var ( + inSlashStar = false // between /* and */ + haveGobuild = false + e constraint.Expr + ) +Lines: + for s.Scan() { + line := bytes.TrimSpace(s.Bytes()) + if !inSlashStar && constraint.IsGoBuild(string(line)) { + if haveGobuild { + return nil, fmt.Errorf("multiple go:build directives") + } + haveGobuild = true + var err error + e, err = constraint.Parse(string(line)) + if err != nil { + return nil, err + } + } + ThisLine: + for len(line) > 0 { + if inSlashStar { + if i := bytes.Index(line, []byte(starSlash)); i >= 0 { + inSlashStar = false + line = bytes.TrimSpace(line[i+len(starSlash):]) + continue ThisLine + } + continue Lines + } + if bytes.HasPrefix(line, []byte("//")) { + continue Lines + } + // Note that if /* appears in the line, but not at the beginning, + // then the line is still non-empty, so skipping this and + // terminating below is correct. + if bytes.HasPrefix(line, []byte(slashStar)) { + inSlashStar = true + line = bytes.TrimSpace(line[len(slashStar):]) + continue ThisLine + } + // A non-empty non-comment line terminates scanning for go:build. + break Lines + } + } + return e, s.Err() +} + +// FromString extracts the build constraint from the Go source or assembly file +// containing the given data. If no build constraint applies to the file, it +// returns nil. 
+func FromString(str string) (constraint.Expr, error) { + return FromReader(strings.NewReader(str)) +} + +// FromFile extracts the build constraint from the Go source or assembly file +// at the given path. If no build constraint applies to the file, it returns +// nil. +func FromFile(path string) (constraint.Expr, error) { + f, err := os.Open(path) + if err != nil { + return nil, err + } + defer f.Close() + return FromReader(f) +} + +// Combine returns a constraint.Expr that evaluates to true iff all expressions +// in es evaluate to true. If es is empty, Combine returns nil. +// +// Preconditions: All constraint.Exprs in es are non-nil. +func Combine(es []constraint.Expr) constraint.Expr { + switch len(es) { + case 0: + return nil + case 1: + return es[0] + default: + a := &constraint.AndExpr{es[0], es[1]} + for i := 2; i < len(es); i++ { + a = &constraint.AndExpr{a, es[i]} + } + return a + } +} + +// CombineFromFiles returns a build constraint expression that evaluates to +// true iff the build constraints from all of the given Go source or assembly +// files evaluate to true. If no build constraints apply to any of the given +// files, it returns nil. +func CombineFromFiles(paths []string) (constraint.Expr, error) { + var es []constraint.Expr + for _, path := range paths { + e, err := FromFile(path) + if err != nil { + return nil, fmt.Errorf("failed to read build constraints from %q: %v", path, err) + } + if e != nil { + es = append(es, e) + } + } + return Combine(es), nil +} + +// Lines returns a string containing build constraint directives for the given +// constraint.Expr, including two trailing newlines, as appropriate for a Go +// source or assembly file. At least a go:build directive will be emitted; if +// the constraint is expressible using +build directives as well, then +build +// directives will also be emitted. +// +// If e is nil, Lines returns the empty string. 
+func Lines(e constraint.Expr) string { + if e == nil { + return "" + } + + var b strings.Builder + b.WriteString("//go:build ") + b.WriteString(e.String()) + b.WriteByte('\n') + + if pblines, err := constraint.PlusBuildLines(e); err == nil { + for _, line := range pblines { + b.WriteString(line) + b.WriteByte('\n') + } + } + + b.WriteByte('\n') + return b.String() +} diff --git a/tools/constraintutil/constraintutil_test.go b/tools/constraintutil/constraintutil_test.go new file mode 100644 index 000000000..eeabd8dcf --- /dev/null +++ b/tools/constraintutil/constraintutil_test.go @@ -0,0 +1,138 @@ +// Copyright 2021 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +package constraintutil + +import ( + "go/build/constraint" + "testing" +) + +func TestFileParsing(t *testing.T) { + for _, test := range []struct { + name string + data string + expr string + }{ + { + name: "Empty", + }, + { + name: "NoConstraint", + data: "// copyright header\n\npackage main", + }, + { + name: "ConstraintOnFirstLine", + data: "//go:build amd64\n#include \"textflag.h\"", + expr: "amd64", + }, + { + name: "ConstraintAfterSlashSlashComment", + data: "// copyright header\n\n//go:build linux\n\npackage newlib", + expr: "linux", + }, + { + name: "ConstraintAfterSlashStarComment", + data: "/*\ncopyright header\n*/\n\n//go:build !race\n\npackage oldlib", + expr: "!race", + }, + { + name: "ConstraintInSlashSlashComment", + data: "// blah blah //go:build windows", + }, + { + name: "ConstraintInSlashStarComment", + data: "/*\n//go:build windows\n*/", + }, + { + name: "ConstraintAfterPackageClause", + data: "package oops\n//go:build race", + }, + { + name: "ConstraintAfterCppInclude", + data: "#include \"textflag.h\"\n//go:build arm64", + }, + } { + t.Run(test.name, func(t *testing.T) { + e, err := FromString(test.data) + if err != nil { + t.Fatalf("FromString(%q) failed: %v", test.data, err) + } + if e == nil { + if len(test.expr) != 0 { + t.Errorf("FromString(%q): got no constraint, wanted %q", test.data, test.expr) + } + } else { + got := e.String() + if len(test.expr) == 0 { + t.Errorf("FromString(%q): got %q, wanted no constraint", test.data, got) + } else if got != test.expr { + t.Errorf("FromString(%q): got %q, wanted %q", test.data, got, test.expr) + } + } + }) + } +} + +func TestCombine(t *testing.T) { + for _, test := range []struct { + name string + in []string + out string + }{ + { + name: "0", + }, + { + name: "1", + in: []string{"amd64 || arm64"}, + out: "amd64 || arm64", + }, + { + name: "2", + in: []string{"amd64", "amd64 && linux"}, + out: "amd64 && amd64 && linux", + }, + { + name: "3", + in: []string{"amd64", "amd64 || arm64", "amd64 || 
riscv64"}, + out: "amd64 && (amd64 || arm64) && (amd64 || riscv64)", + }, + } { + t.Run(test.name, func(t *testing.T) { + inexprs := make([]constraint.Expr, 0, len(test.in)) + for _, estr := range test.in { + line := "//go:build " + estr + e, err := constraint.Parse(line) + if err != nil { + t.Fatalf("constraint.Parse(%q) failed: %v", line, err) + } + inexprs = append(inexprs, e) + } + outexpr := Combine(inexprs) + if outexpr == nil { + if len(test.out) != 0 { + t.Errorf("Combine(%v): got no constraint, wanted %q", test.in, test.out) + } + } else { + got := outexpr.String() + if len(test.out) == 0 { + t.Errorf("Combine(%v): got %q, wanted no constraint", test.in, got) + } else if got != test.out { + t.Errorf("Combine(%v): got %q, wanted %q", test.in, got, test.out) + } + } + }) + } +} diff --git a/tools/go_generics/go_merge/BUILD b/tools/go_generics/go_merge/BUILD index 5e0487e93..211e6b3ed 100644 --- a/tools/go_generics/go_merge/BUILD +++ b/tools/go_generics/go_merge/BUILD @@ -7,6 +7,6 @@ go_binary( srcs = ["main.go"], visibility = ["//:sandbox"], deps = [ - "//tools/tags", + "//tools/constraintutil", ], ) diff --git a/tools/go_generics/go_merge/main.go b/tools/go_generics/go_merge/main.go index 801f2354f..81394ddce 100644 --- a/tools/go_generics/go_merge/main.go +++ b/tools/go_generics/go_merge/main.go @@ -25,9 +25,8 @@ import ( "os" "path/filepath" "strconv" - "strings" - "gvisor.dev/gvisor/tools/tags" + "gvisor.dev/gvisor/tools/constraintutil" ) var ( @@ -131,6 +130,12 @@ func main() { } f.Decls = newDecls + // Infer build constraints for the output file. + bcexpr, err := constraintutil.CombineFromFiles(flag.Args()) + if err != nil { + fatalf("Failed to read build constraints: %v\n", err) + } + // Write the output file. 
var buf bytes.Buffer if err := format.Node(&buf, fset, f); err != nil { @@ -141,9 +146,7 @@ func main() { fatalf("opening output: %v\n", err) } defer outf.Close() - if t := tags.Aggregate(flag.Args()); len(t) > 0 { - fmt.Fprintf(outf, "%s\n\n", strings.Join(t.Lines(), "\n")) - } + outf.WriteString(constraintutil.Lines(bcexpr)) if _, err := outf.Write(buf.Bytes()); err != nil { fatalf("write: %v\n", err) } diff --git a/tools/go_marshal/gomarshal/BUILD b/tools/go_marshal/gomarshal/BUILD index c2747d94c..aaa203115 100644 --- a/tools/go_marshal/gomarshal/BUILD +++ b/tools/go_marshal/gomarshal/BUILD @@ -18,5 +18,5 @@ go_library( visibility = [ "//:sandbox", ], - deps = ["//tools/tags"], + deps = ["//tools/constraintutil"], ) diff --git a/tools/go_marshal/gomarshal/generator.go b/tools/go_marshal/gomarshal/generator.go index 00961c90d..4c23637c0 100644 --- a/tools/go_marshal/gomarshal/generator.go +++ b/tools/go_marshal/gomarshal/generator.go @@ -25,7 +25,7 @@ import ( "sort" "strings" - "gvisor.dev/gvisor/tools/tags" + "gvisor.dev/gvisor/tools/constraintutil" ) // List of identifiers we use in generated code that may conflict with a @@ -123,16 +123,18 @@ func (g *Generator) writeHeader() error { var b sourceBuffer b.emit("// Automatically generated marshal implementation. See tools/go_marshal.\n\n") - // Emit build tags. - b.emit("// If there are issues with build tag aggregation, see\n") - b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The build tags here\n") - b.emit("// come from the input set of files used to generate this file. This input set\n") - b.emit("// is filtered based on pre-defined file suffixes related to build tags, see \n") - b.emit("// tools/defs.bzl:calculate_sets().\n\n") - - if t := tags.Aggregate(g.inputs); len(t) > 0 { - b.emit(strings.Join(t.Lines(), "\n")) - b.emit("\n\n") + bcexpr, err := constraintutil.CombineFromFiles(g.inputs) + if err != nil { + return err + } + if bcexpr != nil { + // Emit build constraints. 
+ b.emit("// If there are issues with build constraint aggregation, see\n") + b.emit("// tools/go_marshal/gomarshal/generator.go:writeHeader(). The constraints here\n") + b.emit("// come from the input set of files used to generate this file. This input set\n") + b.emit("// is filtered based on pre-defined file suffixes related to build constraints,\n") + b.emit("// see tools/defs.bzl:calculate_sets().\n\n") + b.emit(constraintutil.Lines(bcexpr)) } // Package header. @@ -553,11 +555,12 @@ func (g *Generator) writeTests(ts []*testGenerator) error { b.reset() b.emit("// Automatically generated marshal tests. See tools/go_marshal.\n\n") - // Emit build tags. - if t := tags.Aggregate(g.inputs); len(t) > 0 { - b.emit(strings.Join(t.Lines(), "\n")) - b.emit("\n\n") + // Emit build constraints. + bcexpr, err := constraintutil.CombineFromFiles(g.inputs) + if err != nil { + return err } + b.emit(constraintutil.Lines(bcexpr)) b.emit("package %s\n\n", g.pkg) if err := b.write(g.outputTest); err != nil { diff --git a/tools/go_stateify/BUILD b/tools/go_stateify/BUILD index 913558b4e..ad66981c7 100644 --- a/tools/go_stateify/BUILD +++ b/tools/go_stateify/BUILD @@ -6,7 +6,7 @@ go_binary( name = "stateify", srcs = ["main.go"], visibility = ["//:sandbox"], - deps = ["//tools/tags"], + deps = ["//tools/constraintutil"], ) bzl_library( diff --git a/tools/go_stateify/main.go b/tools/go_stateify/main.go index 93022f504..7216388a0 100644 --- a/tools/go_stateify/main.go +++ b/tools/go_stateify/main.go @@ -28,7 +28,7 @@ import ( "strings" "sync" - "gvisor.dev/gvisor/tools/tags" + "gvisor.dev/gvisor/tools/constraintutil" ) var ( @@ -214,10 +214,13 @@ func main() { // Automated warning. fmt.Fprint(outputFile, "// automatically generated by stateify.\n\n") - // Emit build tags. - if t := tags.Aggregate(flag.Args()); len(t) > 0 { - fmt.Fprintf(outputFile, "%s\n\n", strings.Join(t.Lines(), "\n")) + // Emit build constraints. 
+ bcexpr, err := constraintutil.CombineFromFiles(flag.Args()) + if err != nil { + fmt.Fprintf(os.Stderr, "Failed to infer build constraints: %v", err) + os.Exit(1) } + outputFile.WriteString(constraintutil.Lines(bcexpr)) // Emit the package name. _, pkg := filepath.Split(*fullPkg) diff --git a/tools/nogo/BUILD b/tools/nogo/BUILD index 27fe48680..d72821377 100644 --- a/tools/nogo/BUILD +++ b/tools/nogo/BUILD @@ -35,6 +35,7 @@ go_library( visibility = ["//:sandbox"], deps = [ "//tools/checkescape", + "//tools/checklinkname", "//tools/checklocks", "//tools/checkunsafe", "//tools/nogo/objdump", diff --git a/tools/nogo/analyzers.go b/tools/nogo/analyzers.go index 2b3c03fec..6705fc905 100644 --- a/tools/nogo/analyzers.go +++ b/tools/nogo/analyzers.go @@ -47,6 +47,7 @@ import ( "honnef.co/go/tools/stylecheck" "gvisor.dev/gvisor/tools/checkescape" + "gvisor.dev/gvisor/tools/checklinkname" "gvisor.dev/gvisor/tools/checklocks" "gvisor.dev/gvisor/tools/checkunsafe" ) @@ -80,6 +81,7 @@ var AllAnalyzers = []*analysis.Analyzer{ unusedresult.Analyzer, checkescape.Analyzer, checkunsafe.Analyzer, + checklinkname.Analyzer, checklocks.Analyzer, } diff --git a/tools/nogo/filter/main.go b/tools/nogo/filter/main.go index d50336b9b..4a925d03c 100644 --- a/tools/nogo/filter/main.go +++ b/tools/nogo/filter/main.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Binary check is the nogo entrypoint. +// Binary filter filters and reports nogo findings.
package main import ( diff --git a/tools/tags/BUILD b/tools/tags/BUILD deleted file mode 100644 index 1c02e2c89..000000000 --- a/tools/tags/BUILD +++ /dev/null @@ -1,11 +0,0 @@ -load("//tools:defs.bzl", "go_library") - -package(licenses = ["notice"]) - -go_library( - name = "tags", - srcs = ["tags.go"], - marshal = False, - stateify = False, - visibility = ["//tools:__subpackages__"], -) diff --git a/tools/tags/tags.go b/tools/tags/tags.go deleted file mode 100644 index f35904e0a..000000000 --- a/tools/tags/tags.go +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2020 The gVisor Authors. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -// Package tags is a utility for parsing build tags. -package tags - -import ( - "fmt" - "io/ioutil" - "strings" -) - -// OrSet is a set of tags on a single line. -// -// Note that tags may include ",", and we don't distinguish this case in the -// logic below. Ideally, this constraints can be split into separate top-level -// build tags in order to resolve any issues. -type OrSet []string - -// Line returns the line for this or. -func (or OrSet) Line() string { - return fmt.Sprintf("// +build %s", strings.Join([]string(or), " ")) -} - -// AndSet is the set of all OrSets. -type AndSet []OrSet - -// Lines returns the lines to be printed. -func (and AndSet) Lines() (ls []string) { - for _, or := range and { - ls = append(ls, or.Line()) - } - return -} - -// Join joins this AndSet with another. 
-func (and AndSet) Join(other AndSet) AndSet { - return append(and, other...) -} - -// Tags returns the unique set of +build tags. -// -// Derived form the runtime's canBuild. -func Tags(file string) (tags AndSet) { - data, err := ioutil.ReadFile(file) - if err != nil { - return nil - } - // Check file contents for // +build lines. - for _, p := range strings.Split(string(data), "\n") { - p = strings.TrimSpace(p) - if p == "" { - continue - } - if !strings.HasPrefix(p, "//") { - break - } - if !strings.Contains(p, "+build") { - continue - } - fields := strings.Fields(p[2:]) - if len(fields) < 1 || fields[0] != "+build" { - continue - } - tags = append(tags, OrSet(fields[1:])) - } - return tags -} - -// Aggregate aggregates all tags from a set of files. -// -// Note that these may be in conflict, in which case the build will fail. -func Aggregate(files []string) (tags AndSet) { - for _, file := range files { - tags = tags.Join(Tags(file)) - } - return tags -} |