diff options
Diffstat (limited to 'pkg')
269 files changed, 13086 insertions, 5079 deletions
diff --git a/pkg/abi/linux/BUILD b/pkg/abi/linux/BUILD index 4a26e28de..a0654df2f 100644 --- a/pkg/abi/linux/BUILD +++ b/pkg/abi/linux/BUILD @@ -55,6 +55,8 @@ go_library( "sched.go", "seccomp.go", "sem.go", + "sem_amd64.go", + "sem_arm64.go", "shm.go", "signal.go", "signalfd.go", diff --git a/pkg/abi/linux/ioctl.go b/pkg/abi/linux/ioctl.go index 7df02dd6d..006b5a525 100644 --- a/pkg/abi/linux/ioctl.go +++ b/pkg/abi/linux/ioctl.go @@ -121,6 +121,9 @@ const ( // Constants from uapi/linux/fsverity.h. const ( + FS_VERITY_HASH_ALG_SHA256 = 1 + FS_VERITY_HASH_ALG_SHA512 = 2 + FS_IOC_ENABLE_VERITY = 1082156677 FS_IOC_MEASURE_VERITY = 3221513862 ) diff --git a/pkg/abi/linux/sem.go b/pkg/abi/linux/sem.go index 487a626cc..1b2f76c0b 100644 --- a/pkg/abi/linux/sem.go +++ b/pkg/abi/linux/sem.go @@ -34,18 +34,6 @@ const ( const SEM_UNDO = 0x1000 -// SemidDS is equivalent to struct semid64_ds. -// -// +marshal -type SemidDS struct { - SemPerm IPCPerm - SemOTime TimeT - SemCTime TimeT - SemNSems uint64 - unused3 uint64 - unused4 uint64 -} - // Sembuf is equivalent to struct sembuf. // // +marshal slice:SembufSlice diff --git a/pkg/abi/linux/sem_amd64.go b/pkg/abi/linux/sem_amd64.go new file mode 100644 index 000000000..ab980cb4f --- /dev/null +++ b/pkg/abi/linux/sem_amd64.go @@ -0,0 +1,33 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build amd64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: arch/x86/include/uapi/asm/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + unused1 uint64 + SemCTime TimeT + unused2 uint64 + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/abi/linux/sem_arm64.go b/pkg/abi/linux/sem_arm64.go new file mode 100644 index 000000000..521468fb1 --- /dev/null +++ b/pkg/abi/linux/sem_arm64.go @@ -0,0 +1,31 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// +build arm64 + +package linux + +// SemidDS is equivalent to struct semid64_ds. +// +// Source: include/uapi/asm-generic/sembuf.h +// +// +marshal +type SemidDS struct { + SemPerm IPCPerm + SemOTime TimeT + SemCTime TimeT + SemNSems uint64 + unused3 uint64 + unused4 uint64 +} diff --git a/pkg/bpf/decoder.go b/pkg/bpf/decoder.go index 069d0395d..6d1e65cb1 100644 --- a/pkg/bpf/decoder.go +++ b/pkg/bpf/decoder.go @@ -109,7 +109,7 @@ func decodeLdSize(inst linux.BPFInstruction, w *bytes.Buffer) error { case B: w.WriteString("1") default: - return fmt.Errorf("Invalid BPF LD size: %v", inst) + return fmt.Errorf("invalid BPF LD size: %v", inst) } return nil } diff --git a/pkg/context/context.go b/pkg/context/context.go index 2613bc752..f3031fc60 100644 --- a/pkg/context/context.go +++ b/pkg/context/context.go @@ -166,3 +166,27 @@ var bgContext = &logContext{Logger: log.Log()} func Background() Context { return bgContext } + +// WithValue returns a copy of parent in which the value associated with key is +// val. +func WithValue(parent Context, key, val interface{}) Context { + return &withValue{ + Context: parent, + key: key, + val: val, + } +} + +type withValue struct { + Context + key interface{} + val interface{} +} + +// Value implements Context.Value. +func (ctx *withValue) Value(key interface{}) interface{} { + if key == ctx.key { + return ctx.val + } + return ctx.Context.Value(key) +} diff --git a/pkg/merkletree/BUILD b/pkg/merkletree/BUILD index a8fcb2e19..501a9ef21 100644 --- a/pkg/merkletree/BUILD +++ b/pkg/merkletree/BUILD @@ -6,12 +6,18 @@ go_library( name = "merkletree", srcs = ["merkletree.go"], visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) go_test( name = "merkletree_test", srcs = ["merkletree_test.go"], library = ":merkletree", - deps = ["//pkg/usermem"], + deps = [ + "//pkg/abi/linux", + "//pkg/usermem", + ], ) diff --git a/pkg/merkletree/merkletree.go b/pkg/merkletree/merkletree.go index d8227b8bd..e0a9e56c5 100644 --- a/pkg/merkletree/merkletree.go +++ b/pkg/merkletree/merkletree.go @@ -18,21 +18,32 @@ package merkletree import ( "bytes" "crypto/sha256" + "crypto/sha512" "fmt" "io" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) const ( // sha256DigestSize specifies the digest size of a SHA256 hash. sha256DigestSize = 32 + // sha512DigestSize specifies the digest size of a SHA512 hash. + sha512DigestSize = 64 ) // DigestSize returns the size (in bytes) of a digest. -// TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). -func DigestSize() int { - return sha256DigestSize +// TODO(b/156980949): Allow config SHA384. +func DigestSize(hashAlgorithm int) int { + switch hashAlgorithm { + case linux.FS_VERITY_HASH_ALG_SHA256: + return sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + return sha512DigestSize + default: + return -1 + } } // Layout defines the scale of a Merkle tree. @@ -51,11 +62,19 @@ type Layout struct { // InitLayout initializes and returns a new Layout object describing the structure // of a tree. dataSize specifies the size of input data in bytes. -func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { +func InitLayout(dataSize int64, hashAlgorithms int, dataAndTreeInSameFile bool) (Layout, error) { layout := Layout{ blockSize: usermem.PageSize, - // TODO(b/156980949): Allow config other hash methods (SHA384/SHA512). - digestSize: sha256DigestSize, + } + + // TODO(b/156980949): Allow config SHA384. + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + layout.digestSize = sha256DigestSize + case linux.FS_VERITY_HASH_ALG_SHA512: + layout.digestSize = sha512DigestSize + default: + return Layout{}, fmt.Errorf("unexpected hash algorithms") } // treeStart is the offset (in bytes) of the first level of the tree in @@ -88,7 +107,7 @@ func InitLayout(dataSize int64, dataAndTreeInSameFile bool) Layout { } layout.levelOffset = append(layout.levelOffset, treeStart+offset*layout.blockSize) - return layout + return layout, nil } // hashesPerBlock() returns the number of digests in each block. For example, @@ -128,6 +147,7 @@ func (layout Layout) blockOffset(level int, index int64) int64 { // meatadata. type VerityDescriptor struct { Name string + FileSize int64 Mode uint32 UID uint32 GID uint32 @@ -135,16 +155,37 @@ type VerityDescriptor struct { } func (d *VerityDescriptor) String() string { - return fmt.Sprintf("Name: %s, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.Mode, d.UID, d.GID, d.RootHash) + return fmt.Sprintf("Name: %s, Size: %d, Mode: %d, UID: %d, GID: %d, RootHash: %v", d.Name, d.FileSize, d.Mode, d.UID, d.GID, d.RootHash) } // verify generates a hash from d, and compares it with expected. -func (d *VerityDescriptor) verify(expected []byte) error { - h := sha256.Sum256([]byte(d.String())) +func (d *VerityDescriptor) verify(expected []byte, hashAlgorithms int) error { + h, err := hashData([]byte(d.String()), hashAlgorithms) + if err != nil { + return err + } if !bytes.Equal(h[:], expected) { return fmt.Errorf("unexpected root hash") } return nil + +} + +// hashData hashes data and returns the result hash based on the hash +// algorithms. +func hashData(data []byte, hashAlgorithms int) ([]byte, error) { + var digest []byte + switch hashAlgorithms { + case linux.FS_VERITY_HASH_ALG_SHA256: + digestArray := sha256.Sum256(data) + digest = digestArray[:] + case linux.FS_VERITY_HASH_ALG_SHA512: + digestArray := sha512.Sum512(data) + digest = digestArray[:] + default: + return nil, fmt.Errorf("unexpected hash algorithms") + } + return digest, nil } // GenerateParams contains the parameters used to generate a Merkle tree. @@ -161,6 +202,8 @@ type GenerateParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // TreeReader is a reader for the Merkle tree. TreeReader io.ReaderAt // TreeWriter is a writer for the Merkle tree. @@ -176,7 +219,10 @@ type GenerateParams struct { // Generate returns a hash of a VerityDescriptor, which contains the file // metadata and the hash from file content. func Generate(params *GenerateParams) ([]byte, error) { - layout := InitLayout(params.Size, params.DataAndTreeInSameFile) + layout, err := InitLayout(params.Size, params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return nil, err + } numBlocks := (params.Size + layout.blockSize - 1) / layout.blockSize @@ -218,10 +264,13 @@ func Generate(params *GenerateParams) ([]byte, error) { return nil, err } // Hash the bytes in buf. - digest := sha256.Sum256(buf) + digest, err := hashData(buf, params.HashAlgorithms) + if err != nil { + return nil, err + } if level == layout.rootLevel() { - root = digest[:] + root = digest } // Write the generated hash to the end of the tree file. @@ -241,13 +290,13 @@ func Generate(params *GenerateParams) ([]byte, error) { } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, RootHash: root, } - ret := sha256.Sum256([]byte(descriptor.String())) - return ret[:], nil + return hashData([]byte(descriptor.String()), params.HashAlgorithms) } // VerifyParams contains the params used to verify a portion of a file against @@ -269,6 +318,8 @@ type VerifyParams struct { UID uint32 // GID is the group ID of the target file. GID uint32 + // HashAlgorithms is the algorithms used to hash data. + HashAlgorithms int // ReadOffset is the offset of the data range to be verified. ReadOffset int64 // ReadSize is the size of the data range to be verified. @@ -293,12 +344,13 @@ func verifyMetadata(params *VerifyParams, layout *Layout) error { } descriptor := VerityDescriptor{ Name: params.Name, + FileSize: params.Size, Mode: params.Mode, UID: params.UID, GID: params.GID, RootHash: root, } - return descriptor.verify(params.Expected) + return descriptor.verify(params.Expected, params.HashAlgorithms) } // Verify verifies the content read from data with offset. The content is @@ -313,7 +365,10 @@ func Verify(params *VerifyParams) (int64, error) { if params.ReadSize < 0 { return 0, fmt.Errorf("unexpected read size: %d", params.ReadSize) } - layout := InitLayout(int64(params.Size), params.DataAndTreeInSameFile) + layout, err := InitLayout(int64(params.Size), params.HashAlgorithms, params.DataAndTreeInSameFile) + if err != nil { + return 0, err + } if params.ReadSize == 0 { return 0, verifyMetadata(params, &layout) } @@ -349,12 +404,13 @@ func Verify(params *VerifyParams) (int64, error) { } } descriptor := VerityDescriptor{ - Name: params.Name, - Mode: params.Mode, - UID: params.UID, - GID: params.GID, + Name: params.Name, + FileSize: params.Size, + Mode: params.Mode, + UID: params.UID, + GID: params.GID, } - if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.Expected); err != nil { + if err := verifyBlock(params.Tree, &descriptor, &layout, buf, i, params.HashAlgorithms, params.Expected); err != nil { return 0, err } @@ -395,7 +451,7 @@ func Verify(params *VerifyParams) (int64, error) { // fails if the calculated hash from block is different from any level of // hashes stored in tree. And the final root hash is compared with // expected. -func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, expected []byte) error { +func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, dataBlock []byte, blockIndex int64, hashAlgorithms int, expected []byte) error { if len(dataBlock) != int(layout.blockSize) { return fmt.Errorf("incorrect block size") } @@ -406,8 +462,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, for level := 0; level < layout.numLevels(); level++ { // Calculate hash. if level == 0 { - digestArray := sha256.Sum256(dataBlock) - digest = digestArray[:] + h, err := hashData(dataBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } else { // Read a block in previous level that contains the // hash we just generated, and generate a next level @@ -415,8 +474,11 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, if _, err := tree.ReadAt(treeBlock, layout.blockOffset(level-1, blockIndex)); err != nil { return err } - digestArray := sha256.Sum256(treeBlock) - digest = digestArray[:] + h, err := hashData(treeBlock, hashAlgorithms) + if err != nil { + return err + } + digest = h } // Read the digest for the current block and store in @@ -434,5 +496,5 @@ func verifyBlock(tree io.ReaderAt, descriptor *VerityDescriptor, layout *Layout, // Verification for the tree succeeded. Now hash the descriptor with // the root hash and compare it with expected. descriptor.RootHash = digest - return descriptor.verify(expected) + return descriptor.verify(expected, hashAlgorithms) } diff --git a/pkg/merkletree/merkletree_test.go b/pkg/merkletree/merkletree_test.go index e1350ebda..405204d94 100644 --- a/pkg/merkletree/merkletree_test.go +++ b/pkg/merkletree/merkletree_test.go @@ -22,54 +22,114 @@ import ( "testing" "time" + "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/usermem" ) func TestLayout(t *testing.T) { testCases := []struct { dataSize int64 + hashAlgorithms int dataAndTreeInSameFile bool + expectedDigestSize int64 expectedLevelOffset []int64 }{ { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0}, }, { dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + dataAndTreeInSameFile: true, + expectedDigestSize: 32, + expectedLevelOffset: []int64{usermem.PageSize}, + }, + { + dataSize: 100, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, dataAndTreeInSameFile: true, + expectedDigestSize: 64, expectedLevelOffset: []int64{usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 2 * usermem.PageSize, 3 * usermem.PageSize}, }, { dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 4 * usermem.PageSize, 5 * usermem.PageSize}, + }, + { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{245 * usermem.PageSize, 247 * usermem.PageSize, 248 * usermem.PageSize}, }, { + dataSize: 1000000, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{245 * usermem.PageSize, 249 * usermem.PageSize, 250 * usermem.PageSize}, + }, + { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: false, + expectedDigestSize: 32, expectedLevelOffset: []int64{0, 32 * usermem.PageSize, 33 * usermem.PageSize}, }, { dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: false, + expectedDigestSize: 64, + expectedLevelOffset: []int64{0, 64 * usermem.PageSize, 65 * usermem.PageSize}, + }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, dataAndTreeInSameFile: true, + expectedDigestSize: 32, expectedLevelOffset: []int64{4096 * usermem.PageSize, 4128 * usermem.PageSize, 4129 * usermem.PageSize}, }, + { + dataSize: 4096 * int64(usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + dataAndTreeInSameFile: true, + expectedDigestSize: 64, + expectedLevelOffset: []int64{4096 * usermem.PageSize, 4160 * usermem.PageSize, 4161 * usermem.PageSize}, + }, } for _, tc := range testCases { t.Run(fmt.Sprintf("%d", tc.dataSize), func(t *testing.T) { - l := InitLayout(tc.dataSize, tc.dataAndTreeInSameFile) + l, err := InitLayout(tc.dataSize, tc.hashAlgorithms, tc.dataAndTreeInSameFile) + if err != nil { + t.Fatalf("Failed to InitLayout: %v", err) + } if l.blockSize != int64(usermem.PageSize) { t.Errorf("Got blockSize %d, want %d", l.blockSize, usermem.PageSize) } - if l.digestSize != sha256DigestSize { + if l.digestSize != tc.expectedDigestSize { t.Errorf("Got digestSize %d, want %d", l.digestSize, sha256DigestSize) } if l.numLevels() != len(tc.expectedLevelOffset) { @@ -118,24 +178,49 @@ func TestGenerate(t *testing.T) { // The input data has size dataSize. It starts with the data in startWith, // and all other bytes are zeroes. testCases := []struct { - data []byte - expectedHash []byte + data []byte + hashAlgorithms int + expectedHash []byte }{ { - data: bytes.Repeat([]byte{0}, usermem.PageSize), - expectedHash: []byte{64, 253, 58, 72, 192, 131, 82, 184, 193, 33, 108, 142, 43, 46, 179, 134, 244, 21, 29, 190, 14, 39, 66, 129, 6, 46, 200, 211, 30, 247, 191, 252}, + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{39, 30, 12, 152, 185, 58, 32, 84, 218, 79, 74, 113, 104, 219, 230, 234, 25, 126, 147, 36, 212, 44, 76, 74, 25, 93, 228, 41, 243, 143, 59, 147}, + }, + { + data: bytes.Repeat([]byte{0}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{184, 76, 172, 204, 17, 136, 127, 75, 224, 42, 251, 181, 98, 149, 1, 44, 58, 148, 20, 187, 30, 174, 73, 87, 166, 9, 109, 169, 42, 96, 87, 202, 59, 82, 174, 80, 51, 95, 101, 100, 6, 246, 56, 120, 27, 166, 29, 59, 67, 115, 227, 121, 241, 177, 63, 238, 82, 157, 43, 107, 174, 180, 44, 84}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{213, 221, 252, 9, 241, 250, 186, 1, 242, 132, 83, 77, 180, 207, 119, 48, 206, 113, 37, 253, 252, 159, 71, 70, 3, 53, 42, 244, 230, 244, 173, 143}, + }, + { + data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{40, 231, 187, 28, 3, 171, 168, 36, 177, 244, 118, 131, 218, 226, 106, 55, 245, 157, 244, 147, 144, 57, 41, 182, 65, 6, 13, 49, 38, 66, 237, 117, 124, 110, 250, 246, 248, 132, 201, 156, 195, 201, 142, 179, 122, 128, 195, 194, 187, 240, 129, 171, 168, 182, 101, 58, 194, 155, 99, 147, 49, 130, 161, 178}, }, { - data: bytes.Repeat([]byte{0}, 128*usermem.PageSize+1), - expectedHash: []byte{182, 223, 218, 62, 65, 185, 160, 219, 93, 119, 186, 88, 205, 32, 122, 231, 173, 72, 78, 76, 65, 57, 177, 146, 159, 39, 44, 123, 230, 156, 97, 26}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{182, 25, 170, 240, 16, 153, 234, 4, 101, 238, 197, 154, 182, 168, 171, 96, 177, 33, 171, 117, 73, 78, 124, 239, 82, 255, 215, 121, 156, 95, 121, 171}, }, { - data: []byte{'a'}, - expectedHash: []byte{28, 201, 8, 36, 150, 178, 111, 5, 193, 212, 129, 205, 206, 124, 211, 90, 224, 142, 81, 183, 72, 165, 243, 240, 242, 241, 76, 127, 101, 61, 63, 11}, + data: []byte{'a'}, + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{121, 28, 140, 244, 32, 222, 61, 255, 184, 65, 117, 84, 132, 197, 122, 214, 95, 249, 164, 77, 211, 192, 217, 59, 109, 255, 249, 253, 27, 142, 110, 29, 93, 153, 92, 211, 178, 198, 136, 34, 61, 157, 141, 94, 145, 191, 201, 134, 141, 138, 51, 26, 33, 187, 17, 196, 113, 234, 125, 219, 4, 41, 57, 120}, }, { - data: bytes.Repeat([]byte{'a'}, usermem.PageSize), - expectedHash: []byte{106, 58, 160, 152, 41, 68, 38, 108, 245, 74, 177, 84, 64, 193, 19, 176, 249, 86, 27, 193, 85, 164, 99, 240, 79, 104, 148, 222, 76, 46, 191, 79}, + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA256, + expectedHash: []byte{17, 40, 99, 150, 206, 124, 196, 184, 41, 40, 50, 91, 113, 47, 8, 204, 2, 102, 202, 86, 157, 92, 218, 53, 151, 250, 234, 247, 191, 121, 113, 246}, + }, + { + data: bytes.Repeat([]byte{'a'}, usermem.PageSize), + hashAlgorithms: linux.FS_VERITY_HASH_ALG_SHA512, + expectedHash: []byte{100, 22, 249, 78, 47, 163, 220, 231, 228, 165, 226, 192, 221, 77, 106, 69, 115, 104, 208, 155, 124, 206, 225, 233, 98, 249, 232, 225, 114, 119, 110, 216, 117, 106, 85, 7, 200, 206, 139, 81, 116, 37, 215, 158, 89, 110, 74, 86, 66, 95, 117, 237, 70, 56, 62, 175, 48, 147, 162, 122, 253, 57, 123, 84}, }, } @@ -149,6 +234,7 @@ func TestGenerate(t *testing.T) { Mode: defaultMode, UID: defaultUID, GID: defaultGID, + HashAlgorithms: tc.hashAlgorithms, TreeReader: &tree, TreeWriter: &tree, DataAndTreeInSameFile: dataAndTreeInSameFile, @@ -189,6 +275,7 @@ func TestVerify(t *testing.T) { // fail, otherwise Verify should still succeed. modifyByte int64 modifyName bool + modifySize bool modifyMode bool modifyUID bool modifyGID bool @@ -237,6 +324,15 @@ func TestVerify(t *testing.T) { modifyName: true, shouldSucceed: false, }, + // Modified size should fail verification. + { + dataSize: usermem.PageSize, + verifyStart: 0, + verifySize: 0, + modifyByte: 0, + modifySize: true, + shouldSucceed: false, + }, // Modified mode should fail verification. { dataSize: usermem.PageSize, @@ -348,77 +444,84 @@ func TestVerify(t *testing.T) { // Generate random bytes in data. rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Flip a bit in data and checks Verify results. - var buf bytes.Buffer - data[tc.modifyByte] ^= 1 - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: tc.dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: tc.verifyStart, - ReadSize: tc.verifySize, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } - if tc.modifyName { - verifyParams.Name = defaultName + "abc" - } - if tc.modifyMode { - verifyParams.Mode = defaultMode + 1 - } - if tc.modifyUID { - verifyParams.UID = defaultUID + 1 - } - if tc.modifyGID { - verifyParams.GID = defaultGID + 1 - } - if tc.shouldSucceed { - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed when expected to succeed: %v", err) + // Flip a bit in data and checks Verify results. + var buf bytes.Buffer + data[tc.modifyByte] ^= 1 + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: tc.dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: tc.verifyStart, + ReadSize: tc.verifySize, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, } - if n != tc.verifySize { - t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + if tc.modifyName { + verifyParams.Name = defaultName + "abc" } - if int64(buf.Len()) != tc.verifySize { - t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + if tc.modifySize { + verifyParams.Size-- } - if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") + if tc.modifyMode { + verifyParams.Mode = defaultMode + 1 } - } else { - if _, err := Verify(&verifyParams); err == nil { - t.Errorf("Verification succeeded when expected to fail") + if tc.modifyUID { + verifyParams.UID = defaultUID + 1 + } + if tc.modifyGID { + verifyParams.GID = defaultGID + 1 + } + if tc.shouldSucceed { + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed when expected to succeed: %v", err) + } + if n != tc.verifySize { + t.Errorf("Got Verify output size %d, want %d", n, tc.verifySize) + } + if int64(buf.Len()) != tc.verifySize { + t.Errorf("Got Verify output buf size %d, want %d,", buf.Len(), tc.verifySize) + } + if !bytes.Equal(data[tc.verifyStart:tc.verifyStart+tc.verifySize], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } + } else { + if _, err := Verify(&verifyParams); err == nil { + t.Errorf("Verification succeeded when expected to fail") + } } } } @@ -435,87 +538,91 @@ func TestVerifyRandom(t *testing.T) { // Generate random bytes in data. rand.Read(data) - for _, dataAndTreeInSameFile := range []bool{false, true} { - var tree bytesReadWriter - genParams := GenerateParams{ - Size: int64(len(data)), - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - TreeReader: &tree, - TreeWriter: &tree, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + for _, hashAlgorithms := range []int{linux.FS_VERITY_HASH_ALG_SHA256, linux.FS_VERITY_HASH_ALG_SHA512} { + for _, dataAndTreeInSameFile := range []bool{false, true} { + var tree bytesReadWriter + genParams := GenerateParams{ + Size: int64(len(data)), + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + TreeReader: &tree, + TreeWriter: &tree, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - if dataAndTreeInSameFile { - tree.Write(data) - genParams.File = &tree - } else { - genParams.File = &bytesReadWriter{ - bytes: data, + if dataAndTreeInSameFile { + tree.Write(data) + genParams.File = &tree + } else { + genParams.File = &bytesReadWriter{ + bytes: data, + } + } + hash, err := Generate(&genParams) + if err != nil { + t.Fatalf("Generate failed: %v", err) } - } - hash, err := Generate(&genParams) - if err != nil { - t.Fatalf("Generate failed: %v", err) - } - // Pick a random portion of data. - start := rand.Int63n(dataSize - 1) - size := rand.Int63n(dataSize) + 1 + // Pick a random portion of data. + start := rand.Int63n(dataSize - 1) + size := rand.Int63n(dataSize) + 1 - var buf bytes.Buffer - verifyParams := VerifyParams{ - Out: &buf, - File: bytes.NewReader(data), - Tree: &tree, - Size: dataSize, - Name: defaultName, - Mode: defaultMode, - UID: defaultUID, - GID: defaultGID, - ReadOffset: start, - ReadSize: size, - Expected: hash, - DataAndTreeInSameFile: dataAndTreeInSameFile, - } + var buf bytes.Buffer + verifyParams := VerifyParams{ + Out: &buf, + File: bytes.NewReader(data), + Tree: &tree, + Size: dataSize, + Name: defaultName, + Mode: defaultMode, + UID: defaultUID, + GID: defaultGID, + HashAlgorithms: hashAlgorithms, + ReadOffset: start, + ReadSize: size, + Expected: hash, + DataAndTreeInSameFile: dataAndTreeInSameFile, + } - // Checks that the random portion of data from the original data is - // verified successfully. - n, err := Verify(&verifyParams) - if err != nil && err != io.EOF { - t.Errorf("Verification failed for correct data: %v", err) - } - if size > dataSize-start { - size = dataSize - start - } - if n != size { - t.Errorf("Got Verify output size %d, want %d", n, size) - } - if int64(buf.Len()) != size { - t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) - } - if !bytes.Equal(data[start:start+size], buf.Bytes()) { - t.Errorf("Incorrect output buf from Verify") - } + // Checks that the random portion of data from the original data is + // verified successfully. + n, err := Verify(&verifyParams) + if err != nil && err != io.EOF { + t.Errorf("Verification failed for correct data: %v", err) + } + if size > dataSize-start { + size = dataSize - start + } + if n != size { + t.Errorf("Got Verify output size %d, want %d", n, size) + } + if int64(buf.Len()) != size { + t.Errorf("Got Verify output buf size %d, want %d", buf.Len(), size) + } + if !bytes.Equal(data[start:start+size], buf.Bytes()) { + t.Errorf("Incorrect output buf from Verify") + } - // Verify that modified metadata should fail verification. - buf.Reset() - verifyParams.Name = defaultName + "abc" - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verify succeeded for modified metadata, expect failure") - } + // Verify that modified metadata should fail verification. + buf.Reset() + verifyParams.Name = defaultName + "abc" + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verify succeeded for modified metadata, expect failure") + } - // Flip a random bit in randPortion, and check that verification fails. - buf.Reset() - randBytePos := rand.Int63n(size) - data[start+randBytePos] ^= 1 - verifyParams.File = bytes.NewReader(data) - verifyParams.Name = defaultName + // Flip a random bit in randPortion, and check that verification fails. + buf.Reset() + randBytePos := rand.Int63n(size) + data[start+randBytePos] ^= 1 + verifyParams.File = bytes.NewReader(data) + verifyParams.Name = defaultName - if _, err := Verify(&verifyParams); err == nil { - t.Error("Verification succeeded for modified data, expect failure") + if _, err := Verify(&verifyParams); err == nil { + t.Error("Verification succeeded for modified data, expect failure") + } } } } diff --git a/pkg/refs/refcounter.go b/pkg/refs/refcounter.go index 699ea8ac3..6992e1de8 100644 --- a/pkg/refs/refcounter.go +++ b/pkg/refs/refcounter.go @@ -319,7 +319,8 @@ func makeStackKey(pcs []uintptr) stackKey { return key } -func recordStack() []uintptr { +// RecordStack constructs and returns the PCs on the current stack. +func RecordStack() []uintptr { pcs := make([]uintptr, maxStackFrames) n := runtime.Callers(1, pcs) if n == 0 { @@ -342,7 +343,8 @@ func recordStack() []uintptr { return v } -func formatStack(pcs []uintptr) string { +// FormatStack converts the given stack into a readable format. +func FormatStack(pcs []uintptr) string { frames := runtime.CallersFrames(pcs) var trace bytes.Buffer for { @@ -367,7 +369,7 @@ func (r *AtomicRefCount) finalize() { if n := r.ReadRefs(); n != 0 { msg := fmt.Sprintf("%sAtomicRefCount %p owned by %q garbage collected with ref count of %d (want 0)", note, r, r.name, n) if len(r.stack) != 0 { - msg += ":\nCaller:\n" + formatStack(r.stack) + msg += ":\nCaller:\n" + FormatStack(r.stack) } else { msg += " (enable trace logging to debug)" } @@ -392,7 +394,7 @@ func (r *AtomicRefCount) EnableLeakCheck(name string) { case NoLeakChecking: return case LeaksLogTraces: - r.stack = recordStack() + r.stack = RecordStack() } r.name = name runtime.SetFinalizer(r, (*AtomicRefCount).finalize) diff --git a/pkg/refs_vfs2/BUILD b/pkg/refsvfs2/BUILD index 577b827a5..bfa1daa10 100644 --- a/pkg/refs_vfs2/BUILD +++ b/pkg/refsvfs2/BUILD @@ -8,6 +8,9 @@ go_template( srcs = [ "refs_template.go", ], + opt_consts = [ + "logTrace", + ], types = [ "T", ], @@ -19,8 +22,16 @@ go_template( ) go_library( - name = "refs_vfs2", - srcs = ["refs.go"], - visibility = ["//pkg/sentry:internal"], - deps = ["//pkg/context"], + name = "refsvfs2", + srcs = [ + "refs.go", + "refs_map.go", + ], + visibility = ["//:sandbox"], + deps = [ + "//pkg/context", + "//pkg/log", + "//pkg/refs", + "//pkg/sync", + ], ) diff --git a/pkg/refs_vfs2/refs.go b/pkg/refsvfs2/refs.go index 99a074e96..ef8beb659 100644 --- a/pkg/refs_vfs2/refs.go +++ b/pkg/refsvfs2/refs.go @@ -12,8 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package refs_vfs2 defines an interface for a reference-counted object. -package refs_vfs2 +// Package refsvfs2 defines an interface for a reference-counted object. +package refsvfs2 import ( "gvisor.dev/gvisor/pkg/context" diff --git a/pkg/refsvfs2/refs_map.go b/pkg/refsvfs2/refs_map.go new file mode 100644 index 000000000..9fbc5466f --- /dev/null +++ b/pkg/refsvfs2/refs_map.go @@ -0,0 +1,131 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package refsvfs2 + +import ( + "fmt" + + "gvisor.dev/gvisor/pkg/log" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/sync" +) + +var ( + // liveObjects is a global map of reference-counted objects. Objects are + // inserted when leak check is enabled, and they are removed when they are + // destroyed. It is protected by liveObjectsMu. + liveObjects map[CheckedObject]struct{} + liveObjectsMu sync.Mutex +) + +// CheckedObject represents a reference-counted object with an informative +// leak detection message. +type CheckedObject interface { + // RefType is the type of the reference-counted object. + RefType() string + + // LeakMessage supplies a warning to be printed upon leak detection. + LeakMessage() string + + // LogRefs indicates whether reference-related events should be logged. + LogRefs() bool +} + +func init() { + liveObjects = make(map[CheckedObject]struct{}) +} + +// leakCheckEnabled returns whether leak checking is enabled. The following +// functions should only be called if it returns true. +func leakCheckEnabled() bool { + return refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking +} + +// Register adds obj to the live object map. +func Register(obj CheckedObject) { + if leakCheckEnabled() { + liveObjectsMu.Lock() + if _, ok := liveObjects[obj]; ok { + panic(fmt.Sprintf("Unexpected entry in leak checking map: reference %p already added", obj)) + } + liveObjects[obj] = struct{}{} + liveObjectsMu.Unlock() + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "registered") + } + } +} + +// Unregister removes obj from the live object map. +func Unregister(obj CheckedObject) { + if leakCheckEnabled() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + if _, ok := liveObjects[obj]; !ok { + panic(fmt.Sprintf("Expected to find entry in leak checking map for reference %p", obj)) + } + delete(liveObjects, obj) + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, "unregistered") + } + } +} + +// LogIncRef logs a reference increment. +func LogIncRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("IncRef to %d", refs)) + } +} + +// LogTryIncRef logs a successful TryIncRef call. +func LogTryIncRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("TryIncRef to %d", refs)) + } +} + +// LogDecRef logs a reference decrement. +func LogDecRef(obj CheckedObject, refs int64) { + if leakCheckEnabled() && obj.LogRefs() { + logEvent(obj, fmt.Sprintf("DecRef to %d", refs)) + } +} + +// logEvent logs a message for the given reference-counted object. +// +// obj.LogRefs() should be checked before calling logEvent, in order to avoid +// calling any text processing needed to evaluate msg. +func logEvent(obj CheckedObject, msg string) { + log.Infof("[%s %p] %s:", obj.RefType(), obj, msg) + log.Infof(refs_vfs1.FormatStack(refs_vfs1.RecordStack())) +} + +// DoLeakCheck iterates through the live object map and logs a message for each +// object. It is called once no reference-counted objects should be reachable +// anymore, at which point anything left in the map is considered a leak. +func DoLeakCheck() { + if leakCheckEnabled() { + liveObjectsMu.Lock() + defer liveObjectsMu.Unlock() + leaked := len(liveObjects) + if leaked > 0 { + log.Warningf("Leak checking detected %d leaked objects:", leaked) + for obj := range liveObjects { + log.Warningf(obj.LeakMessage()) + } + } + } +} diff --git a/pkg/refs_vfs2/refs_template.go b/pkg/refsvfs2/refs_template.go index d9b552896..8f50b4ee6 100644 --- a/pkg/refs_vfs2/refs_template.go +++ b/pkg/refsvfs2/refs_template.go @@ -21,20 +21,24 @@ package refs_template import ( "fmt" - "runtime" "sync/atomic" - "gvisor.dev/gvisor/pkg/log" - refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" ) +// enableLogging indicates whether reference-related events should be logged (with +// stack traces). This is false by default and should only be set to true for +// debugging purposes, as it can generate an extremely large amount of output +// and drastically degrade performance. +const enableLogging = false + // T is the type of the reference counted object. It is only used to customize // debug output when leak checking. type T interface{} -// ownerType is used to customize logging. Note that we use a pointer to T so -// that we do not copy the entire object when passed as a format parameter. -var ownerType *T +// obj is used to customize logging. Note that we use a pointer to T so that +// we do not copy the entire object when passed as a format parameter. +var obj *T // Refs implements refs.RefCounter. It keeps a reference count using atomic // operations and calls the destructor when the count reaches zero. @@ -42,11 +46,6 @@ var ownerType *T // Note that the number of references is actually refCount + 1 so that a default // zero-value Refs object contains one reference. // -// TODO(gvisor.dev/issue/1486): Store stack traces when leak check is enabled in -// a map with 16-bit hashes, and store the hash in the top 16 bits of refCount. -// This will allow us to add stack trace information to the leak messages -// without growing the size of Refs. -// // +stateify savable type Refs struct { // refCount is composed of two fields: @@ -59,24 +58,24 @@ type Refs struct { refCount int64 } -func (r *Refs) finalize() { - var note string - switch refs_vfs1.GetLeakMode() { - case refs_vfs1.NoLeakChecking: - return - case refs_vfs1.UninitializedLeakChecking: - note = "(Leak checker uninitialized): " - } - if n := r.ReadRefs(); n != 0 { - log.Warningf("%sRefs %p owned by %T garbage collected with ref count of %d (want 0)", note, r, ownerType, n) - } +// RefType implements refsvfs2.CheckedObject.RefType. +func (r *Refs) RefType() string { + return fmt.Sprintf("%T", obj)[1:] +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (r *Refs) LeakMessage() string { + return fmt.Sprintf("[%s %p] reference count of %d instead of 0", r.RefType(), r, r.ReadRefs()) } -// EnableLeakCheck checks for reference leaks when Refs gets garbage collected. +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +func (r *Refs) LogRefs() bool { + return enableLogging +} + +// EnableLeakCheck enables reference leak checking on r. func (r *Refs) EnableLeakCheck() { - if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { - runtime.SetFinalizer(r, (*Refs).finalize) - } + refsvfs2.Register(r) } // ReadRefs returns the current number of references. The returned count is @@ -90,8 +89,10 @@ func (r *Refs) ReadRefs() int64 { // //go:nosplit func (r *Refs) IncRef() { - if v := atomic.AddInt64(&r.refCount, 1); v <= 0 { - panic(fmt.Sprintf("Incrementing non-positive ref count %p owned by %T", r, ownerType)) + v := atomic.AddInt64(&r.refCount, 1) + refsvfs2.LogIncRef(r, v+1) + if v <= 0 { + panic(fmt.Sprintf("Incrementing non-positive count %p on %s", r, r.RefType())) } } @@ -104,15 +105,15 @@ func (r *Refs) IncRef() { //go:nosplit func (r *Refs) TryIncRef() bool { const speculativeRef = 1 << 32 - v := atomic.AddInt64(&r.refCount, speculativeRef) - if int32(v) < 0 { + if v := atomic.AddInt64(&r.refCount, speculativeRef); int32(v) < 0 { // This object has already been freed. atomic.AddInt64(&r.refCount, -speculativeRef) return false } // Turn into a real reference. - atomic.AddInt64(&r.refCount, -speculativeRef+1) + v := atomic.AddInt64(&r.refCount, -speculativeRef+1) + refsvfs2.LogTryIncRef(r, v+1) return true } @@ -129,14 +130,23 @@ func (r *Refs) TryIncRef() bool { // //go:nosplit func (r *Refs) DecRef(destroy func()) { - switch v := atomic.AddInt64(&r.refCount, -1); { + v := atomic.AddInt64(&r.refCount, -1) + refsvfs2.LogDecRef(r, v+1) + switch { case v < -1: - panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %T", r, ownerType)) + panic(fmt.Sprintf("Decrementing non-positive ref count %p, owned by %s", r, r.RefType())) case v == -1: + refsvfs2.Unregister(r) // Call the destructor. if destroy != nil { destroy() } } } + +func (r *Refs) afterLoad() { + if r.ReadRefs() > 0 { + r.EnableLeakCheck() + } +} diff --git a/pkg/sentry/control/state.go b/pkg/sentry/control/state.go index 41feeffe3..d800f2c85 100644 --- a/pkg/sentry/control/state.go +++ b/pkg/sentry/control/state.go @@ -69,5 +69,5 @@ func (s *State) Save(o *SaveOpts, _ *struct{}) error { s.Kernel.Kill(kernel.ExitStatus{}) }, } - return saveOpts.Save(s.Kernel, s.Watchdog) + return saveOpts.Save(s.Kernel.SupervisorContext(), s.Kernel, s.Watchdog) } diff --git a/pkg/sentry/devices/tundev/tundev.go b/pkg/sentry/devices/tundev/tundev.go index 655ea549b..ff5d49fbd 100644 --- a/pkg/sentry/devices/tundev/tundev.go +++ b/pkg/sentry/devices/tundev/tundev.go @@ -39,6 +39,8 @@ const ( ) // tunDevice implements vfs.Device for /dev/net/tun. +// +// +stateify savable type tunDevice struct{} // Open implements vfs.Device.Open. @@ -53,6 +55,8 @@ func (tunDevice) Open(ctx context.Context, mnt *vfs.Mount, vfsd *vfs.Dentry, opt } // tunFD implements vfs.FileDescriptionImpl for /dev/net/tun. +// +// +stateify savable type tunFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/fs/fsutil/host_file_mapper.go b/pkg/sentry/fs/fsutil/host_file_mapper.go index 1390a9a7f..4468f5dd2 100644 --- a/pkg/sentry/fs/fsutil/host_file_mapper.go +++ b/pkg/sentry/fs/fsutil/host_file_mapper.go @@ -70,6 +70,13 @@ func (f *HostFileMapper) Init() { f.mappings = make(map[uint64]mapping) } +// IsInited returns true if f.Init() has been called. This is used when +// restoring a checkpoint that contains a HostFileMapper that may or may not +// have been initialized. +func (f *HostFileMapper) IsInited() bool { + return f.refs != nil +} + // NewHostFileMapper returns an initialized HostFileMapper allocated on the // heap with no references or cached mappings. func NewHostFileMapper() *HostFileMapper { diff --git a/pkg/sentry/fs/gofer/path.go b/pkg/sentry/fs/gofer/path.go index 3c66dc3c2..6b3627813 100644 --- a/pkg/sentry/fs/gofer/path.go +++ b/pkg/sentry/fs/gofer/path.go @@ -25,7 +25,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // maxFilenameLen is the maximum length of a filename. This is dictated by 9P's @@ -305,7 +304,7 @@ func (i *inodeOperations) createInternalFifo(ctx context.Context, dir *fs.Inode, } // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Wrap the fileOps with our Fifo. iops := &fifo{ diff --git a/pkg/sentry/fs/proc/sys_net.go b/pkg/sentry/fs/proc/sys_net.go index e555672ad..52061175f 100644 --- a/pkg/sentry/fs/proc/sys_net.go +++ b/pkg/sentry/fs/proc/sys_net.go @@ -86,9 +86,9 @@ func (*tcpMemInode) Truncate(context.Context, *fs.Inode, int64) error { } // GetFile implements fs.InodeOperations.GetFile. -func (m *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { +func (t *tcpMemInode) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { flags.Pread = true - return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: m}), nil + return fs.NewFile(ctx, dirent, flags, &tcpMemFile{tcpMemInode: t}), nil } // +stateify savable diff --git a/pkg/sentry/fs/proc/task.go b/pkg/sentry/fs/proc/task.go index 22d658acf..450044c9c 100644 --- a/pkg/sentry/fs/proc/task.go +++ b/pkg/sentry/fs/proc/task.go @@ -92,6 +92,7 @@ func (p *proc) newTaskDir(t *kernel.Task, msrc *fs.MountSource, isThreadGroup bo "gid_map": newGIDMap(t, msrc), "io": newIO(t, msrc, isThreadGroup), "maps": newMaps(t, msrc), + "mem": newMem(t, msrc), "mountinfo": seqfile.NewSeqFileInode(t, &mountInfoFile{t: t}, msrc), "mounts": seqfile.NewSeqFileInode(t, &mountsFile{t: t}, msrc), "net": newNetDir(t, msrc), @@ -399,6 +400,88 @@ func newNamespaceDir(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { return newProcInode(t, d, msrc, fs.SpecialDirectory, t) } +// memData implements fs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memData struct { + fsutil.SimpleFileInode + + t *kernel.Task +} + +// memDataFile implements fs.FileOperations for /proc/[pid]/mem. +// +// +stateify savable +type memDataFile struct { + fsutil.FileGenericSeek `state:"nosave"` + fsutil.FileNoIoctl `state:"nosave"` + fsutil.FileNoMMap `state:"nosave"` + fsutil.FileNoWrite `state:"nosave"` + fsutil.FileNoSplice `state:"nosave"` + fsutil.FileNoopFlush `state:"nosave"` + fsutil.FileNoopFsync `state:"nosave"` + fsutil.FileNoopRelease `state:"nosave"` + fsutil.FileNotDirReaddir `state:"nosave"` + fsutil.FileUseInodeUnstableAttr `state:"nosave"` + waiter.AlwaysReady `state:"nosave"` + + t *kernel.Task +} + +func newMem(t *kernel.Task, msrc *fs.MountSource) *fs.Inode { + inode := &memData{ + SimpleFileInode: *fsutil.NewSimpleFileInode(t, fs.RootOwner, fs.FilePermsFromMode(0400), linux.PROC_SUPER_MAGIC), + t: t, + } + return newProcInode(t, inode, msrc, fs.SpecialFile, t) +} + +// Truncate implements fs.InodeOperations.Truncate. +func (m *memData) Truncate(context.Context, *fs.Inode, int64) error { + return nil +} + +// GetFile implements fs.InodeOperations.GetFile. +func (m *memData) GetFile(ctx context.Context, dirent *fs.Dirent, flags fs.FileFlags) (*fs.File, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, m.t, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(m.t); err != nil { + return nil, err + } + // Enable random access reads + flags.Pread = true + return fs.NewFile(ctx, dirent, flags, &memDataFile{t: m.t}), nil +} + +// Read implements fs.FileOperations.Read. +func (m *memDataFile) Read(ctx context.Context, _ *fs.File, dst usermem.IOSequence, offset int64) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + mm, err := getTaskMM(m.t) + if err != nil { + return 0, nil + } + defer mm.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := mm.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + // mapsData implements seqfile.SeqSource for /proc/[pid]/maps. // // +stateify savable diff --git a/pkg/sentry/fs/tmpfs/inode_file.go b/pkg/sentry/fs/tmpfs/inode_file.go index fc0498f17..d6c65301c 100644 --- a/pkg/sentry/fs/tmpfs/inode_file.go +++ b/pkg/sentry/fs/tmpfs/inode_file.go @@ -431,9 +431,6 @@ func (rw *fileReadWriter) ReadToBlocks(dsts safemem.BlockSeq) (uint64, error) { // Continue. seg, gap = gap.NextSegment(), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil @@ -532,9 +529,6 @@ func (rw *fileReadWriter) WriteFromBlocks(srcs safemem.BlockSeq) (uint64, error) // Write to that memory as usual. seg, gap = rw.f.data.Insert(gap, gapMR, fr.Start), fsutil.FileRangeGapIterator{} - - default: - break } } return done, nil diff --git a/pkg/sentry/fs/tmpfs/tmpfs.go b/pkg/sentry/fs/tmpfs/tmpfs.go index 998b697ca..cf4ed5de0 100644 --- a/pkg/sentry/fs/tmpfs/tmpfs.go +++ b/pkg/sentry/fs/tmpfs/tmpfs.go @@ -336,7 +336,7 @@ type Fifo struct { // NewFifo creates a new named pipe. func NewFifo(ctx context.Context, owner fs.FileOwner, perms fs.FilePermissions, msrc *fs.MountSource) *fs.Inode { // First create a pipe. - p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize) + p := pipe.NewPipe(true /* isNamed */, pipe.DefaultPipeSize) // Build pipe InodeOperations. iops := pipe.NewInodeOperations(ctx, perms, p) diff --git a/pkg/sentry/fsimpl/devpts/BUILD b/pkg/sentry/fsimpl/devpts/BUILD index 84baaac66..6af3c3781 100644 --- a/pkg/sentry/fsimpl/devpts/BUILD +++ b/pkg/sentry/fsimpl/devpts/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "root_inode_refs.go", package = "devpts", prefix = "rootInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "rootInode", }, @@ -33,6 +33,7 @@ go_library( "//pkg/marshal", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/devpts/devpts.go b/pkg/sentry/fsimpl/devpts/devpts.go index d5c5aaa8c..346cca558 100644 --- a/pkg/sentry/fsimpl/devpts/devpts.go +++ b/pkg/sentry/fsimpl/devpts/devpts.go @@ -60,7 +60,7 @@ func (fstype *FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Vir } fstype.initOnce.Do(func() { - fs, root, err := fstype.newFilesystem(vfsObj, creds) + fs, root, err := fstype.newFilesystem(ctx, vfsObj, creds) if err != nil { fstype.initErr = err return @@ -93,7 +93,7 @@ type filesystem struct { // newFilesystem creates a new devpts filesystem with root directory and ptmx // master inode. It returns the filesystem and root Dentry. -func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { +func (fstype *FilesystemType) newFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials) (*filesystem, *kernfs.Dentry, error) { devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err @@ -108,19 +108,19 @@ func (fstype *FilesystemType) newFilesystem(vfsObj *vfs.VirtualFilesystem, creds root := &rootInode{ replicas: make(map[uint32]*replicaInode), } - root.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) + root.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 1, linux.ModeDirectory|0555) root.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) root.EnableLeakCheck() var rootD kernfs.Dentry - rootD.Init(&fs.Filesystem, root) + rootD.InitRoot(&fs.Filesystem, root) // Construct the pts master inode and dentry. Linux always uses inode // id 2 for ptmx. See fs/devpts/inode.c:mknod_ptmx. master := &masterInode{ root: root, } - master.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) + master.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, devMinor, 2, linux.ModeCharacterDevice|0666) // Add the master as a child of the root. links := root.OrderedChildren.Populate(map[string]kernfs.Inode{ @@ -170,7 +170,7 @@ type rootInode struct { var _ kernfs.Inode = (*rootInode)(nil) // allocateTerminal creates a new Terminal and installs a pts node for it. -func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) { +func (i *rootInode) allocateTerminal(ctx context.Context, creds *auth.Credentials) (*Terminal, error) { i.mu.Lock() defer i.mu.Unlock() if i.nextIdx == math.MaxUint32 { @@ -192,7 +192,7 @@ func (i *rootInode) allocateTerminal(creds *auth.Credentials) (*Terminal, error) } // Linux always uses pty index + 3 as the inode id. See // fs/devpts/inode.c:devpts_pty_new(). - replica.InodeAttrs.Init(creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) + replica.InodeAttrs.Init(ctx, creds, i.InodeAttrs.DevMajor(), i.InodeAttrs.DevMinor(), uint64(idx+3), linux.ModeCharacterDevice|0600) i.replicas[idx] = replica return t, nil @@ -248,9 +248,10 @@ func (i *rootInode) Lookup(ctx context.Context, name string) (kernfs.Inode, erro } // IterDirents implements kernfs.Inode.IterDirents. -func (i *rootInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *rootInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { i.mu.Lock() defer i.mu.Unlock() + i.InodeAttrs.TouchAtime(ctx, mnt) ids := make([]int, 0, len(i.replicas)) for id := range i.replicas { ids = append(ids, int(id)) diff --git a/pkg/sentry/fsimpl/devpts/line_discipline.go b/pkg/sentry/fsimpl/devpts/line_discipline.go index e6b0e81cf..ae95fdd08 100644 --- a/pkg/sentry/fsimpl/devpts/line_discipline.go +++ b/pkg/sentry/fsimpl/devpts/line_discipline.go @@ -100,10 +100,10 @@ type lineDiscipline struct { column int // masterWaiter is used to wait on the master end of the TTY. - masterWaiter waiter.Queue `state:"zerovalue"` + masterWaiter waiter.Queue // replicaWaiter is used to wait on the replica end of the TTY. - replicaWaiter waiter.Queue `state:"zerovalue"` + replicaWaiter waiter.Queue } func newLineDiscipline(termios linux.KernelTermios) *lineDiscipline { diff --git a/pkg/sentry/fsimpl/devpts/master.go b/pkg/sentry/fsimpl/devpts/master.go index fda30fb93..e91fa26a4 100644 --- a/pkg/sentry/fsimpl/devpts/master.go +++ b/pkg/sentry/fsimpl/devpts/master.go @@ -50,7 +50,7 @@ var _ kernfs.Inode = (*masterInode)(nil) // Open implements kernfs.Inode.Open. func (mi *masterInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { - t, err := mi.root.allocateTerminal(rp.Credentials()) + t, err := mi.root.allocateTerminal(ctx, rp.Credentials()) if err != nil { return nil, err } diff --git a/pkg/sentry/fsimpl/devtmpfs/BUILD b/pkg/sentry/fsimpl/devtmpfs/BUILD index 01bbee5ad..e49a04c1b 100644 --- a/pkg/sentry/fsimpl/devtmpfs/BUILD +++ b/pkg/sentry/fsimpl/devtmpfs/BUILD @@ -4,7 +4,10 @@ licenses(["notice"]) go_library( name = "devtmpfs", - srcs = ["devtmpfs.go"], + srcs = [ + "devtmpfs.go", + "save_restore.go", + ], visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", diff --git a/pkg/sentry/fsimpl/devtmpfs/save_restore.go b/pkg/sentry/fsimpl/devtmpfs/save_restore.go new file mode 100644 index 000000000..28832d850 --- /dev/null +++ b/pkg/sentry/fsimpl/devtmpfs/save_restore.go @@ -0,0 +1,23 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package devtmpfs + +// afterLoad is invoked by stateify. +func (fst *FilesystemType) afterLoad() { + if fst.fs != nil { + // Ensure that we don't create another filesystem. + fst.initOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/eventfd/eventfd.go b/pkg/sentry/fsimpl/eventfd/eventfd.go index 1c27ad700..5b29f2358 100644 --- a/pkg/sentry/fsimpl/eventfd/eventfd.go +++ b/pkg/sentry/fsimpl/eventfd/eventfd.go @@ -43,7 +43,7 @@ type EventFileDescription struct { // queue is used to notify interested parties when the event object // becomes readable or writable. - queue waiter.Queue `state:"zerovalue"` + queue waiter.Queue // mu protects the fields below. mu sync.Mutex `state:"nosave"` diff --git a/pkg/sentry/fsimpl/fuse/BUILD b/pkg/sentry/fsimpl/fuse/BUILD index 045d7ab08..2158b1bbc 100644 --- a/pkg/sentry/fsimpl/fuse/BUILD +++ b/pkg/sentry/fsimpl/fuse/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "inode_refs.go", package = "fuse", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -49,6 +49,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fsimpl/devtmpfs", "//pkg/sentry/fsimpl/kernfs", diff --git a/pkg/sentry/fsimpl/fuse/dev_test.go b/pkg/sentry/fsimpl/fuse/dev_test.go index 5986133e9..95c475a65 100644 --- a/pkg/sentry/fsimpl/fuse/dev_test.go +++ b/pkg/sentry/fsimpl/fuse/dev_test.go @@ -315,7 +315,7 @@ func fuseServerRun(t *testing.T, s *testutil.System, k *kernel.Kernel, fd *vfs.F readPayload.MarshalUnsafe(outBuf[outHdrLen:]) outIOseq := usermem.BytesIOSequence(outBuf) - n, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) + _, err = fd.Write(s.Ctx, outIOseq, vfs.WriteOptions{}) if err != nil { t.Fatalf("Write failed :%v", err) } diff --git a/pkg/sentry/fsimpl/fuse/fusefs.go b/pkg/sentry/fsimpl/fuse/fusefs.go index e39df21c6..6de416da0 100644 --- a/pkg/sentry/fsimpl/fuse/fusefs.go +++ b/pkg/sentry/fsimpl/fuse/fusefs.go @@ -205,7 +205,7 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // root is the fusefs root directory. - root := fs.newRootInode(creds, fsopts.rootMode) + root := fs.newRoot(ctx, creds, fsopts.rootMode) return fs.VFSFilesystem(), root.VFSDentry(), nil } @@ -284,21 +284,21 @@ type inode struct { link string } -func (fs *filesystem) newRootInode(creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { +func (fs *filesystem) newRoot(ctx context.Context, creds *auth.Credentials, mode linux.FileMode) *kernfs.Dentry { i := &inode{fs: fs, nodeID: 1} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) + i.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, 1, linux.ModeDirectory|0755) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() var d kernfs.Dentry - d.Init(&fs.Filesystem, i) + d.InitRoot(&fs.Filesystem, i) return &d } -func (fs *filesystem) newInode(nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { +func (fs *filesystem) newInode(ctx context.Context, nodeID uint64, attr linux.FUSEAttr) kernfs.Inode { i := &inode{fs: fs, nodeID: nodeID} creds := auth.Credentials{EffectiveKGID: auth.KGID(attr.UID), EffectiveKUID: auth.KUID(attr.UID)} - i.InodeAttrs.Init(&creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) + i.InodeAttrs.Init(ctx, &creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.FileMode(attr.Mode)) atomic.StoreUint64(&i.size, attr.Size) i.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) i.EnableLeakCheck() @@ -424,7 +424,7 @@ func (i *inode) Keep() bool { } // IterDirents implements kernfs.Inode.IterDirents. -func (*inode) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (*inode) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { return offset, nil } @@ -544,7 +544,7 @@ func (i *inode) newEntry(ctx context.Context, name string, fileType linux.FileMo if opcode != linux.FUSE_LOOKUP && ((out.Attr.Mode&linux.S_IFMT)^uint32(fileType) != 0 || out.NodeID == 0 || out.NodeID == linux.FUSE_ROOT_ID) { return nil, syserror.EIO } - child := i.fs.newInode(out.NodeID, out.Attr) + child := i.fs.newInode(ctx, out.NodeID, out.Attr) return child, nil } @@ -696,7 +696,7 @@ func (i *inode) getAttr(ctx context.Context, fs *vfs.Filesystem, opts vfs.StatOp } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return linux.FUSEAttr{}, err @@ -812,7 +812,7 @@ func (i *inode) setAttr(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre } // Set the metadata of kernfs.InodeAttrs. - if err := i.SetInodeStat(ctx, fs, creds, vfs.SetStatOptions{ + if err := i.InodeAttrs.SetStat(ctx, fs, creds, vfs.SetStatOptions{ Stat: statFromFUSEAttr(out.Attr, linux.STATX_ALL, i.fs.devMinor), }); err != nil { return err diff --git a/pkg/sentry/fsimpl/fuse/read_write.go b/pkg/sentry/fsimpl/fuse/read_write.go index 625d1547f..2d396e84c 100644 --- a/pkg/sentry/fsimpl/fuse/read_write.go +++ b/pkg/sentry/fsimpl/fuse/read_write.go @@ -132,7 +132,7 @@ func (fs *filesystem) ReadCallback(ctx context.Context, fd *regularFileFD, off u // May need to update the signature. i := fd.inode() - // TODO(gvisor.dev/issue/1193): Invalidate or update atime. + i.InodeAttrs.TouchAtime(ctx, fd.vfsfd.Mount()) // Reached EOF. if sizeRead < size { @@ -179,6 +179,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, Flags: fd.statusFlags(), } + inode := fd.inode() var written uint32 // This loop is intended for fragmented write where the bytes to write is @@ -203,7 +204,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, in.Offset = off + uint64(written) in.Size = toWrite - req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), fd.inode().nodeID, linux.FUSE_WRITE, &in) + req, err := fs.conn.NewRequest(auth.CredentialsFromContext(ctx), uint32(t.ThreadID()), inode.nodeID, linux.FUSE_WRITE, &in) if err != nil { return 0, err } @@ -237,6 +238,7 @@ func (fs *filesystem) Write(ctx context.Context, fd *regularFileFD, off uint64, break } } + inode.InodeAttrs.TouchCMtime(ctx) return written, nil } diff --git a/pkg/sentry/fsimpl/gofer/BUILD b/pkg/sentry/fsimpl/gofer/BUILD index ad0afc41b..4c3e9acf8 100644 --- a/pkg/sentry/fsimpl/gofer/BUILD +++ b/pkg/sentry/fsimpl/gofer/BUILD @@ -38,6 +38,7 @@ go_library( "host_named_pipe.go", "p9file.go", "regular_file.go", + "save_restore.go", "socket.go", "special_file.go", "symlink.go", @@ -53,6 +54,7 @@ go_library( "//pkg/log", "//pkg/p9", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", @@ -70,6 +72,7 @@ go_library( "//pkg/sentry/socket/unix/transport", "//pkg/sentry/usage", "//pkg/sentry/vfs", + "//pkg/sync", "//pkg/syserr", "//pkg/syserror", "//pkg/unet", diff --git a/pkg/sentry/fsimpl/gofer/directory.go b/pkg/sentry/fsimpl/gofer/directory.go index 18c884b59..ce1b2a390 100644 --- a/pkg/sentry/fsimpl/gofer/directory.go +++ b/pkg/sentry/fsimpl/gofer/directory.go @@ -16,16 +16,17 @@ package gofer import ( "fmt" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -92,7 +93,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { child := &dentry{ refs: 1, // held by d fs: d.fs, - ino: d.fs.nextSyntheticIno(), + ino: d.fs.nextIno(), mode: uint32(opts.mode), uid: uint32(opts.kuid), gid: uint32(opts.kgid), @@ -100,6 +101,7 @@ func (d *dentry) createSyntheticChildLocked(opts *createSyntheticOpts) { hostFD: -1, nlink: uint32(2), } + refsvfs2.Register(child) switch opts.mode.FileType() { case linux.S_IFDIR: // Nothing else needs to be done. @@ -235,7 +237,7 @@ func (d *dentry) getDirents(ctx context.Context) ([]vfs.Dirent, error) { } dirent := vfs.Dirent{ Name: p9d.Name, - Ino: uint64(inoFromPath(p9d.QID.Path)), + Ino: d.fs.inoFromQIDPath(p9d.QID.Path), NextOff: int64(len(dirents) + 1), } // p9 does not expose 9P2000.U's DMDEVICE, DMNAMEDPIPE, or diff --git a/pkg/sentry/fsimpl/gofer/filesystem.go b/pkg/sentry/fsimpl/gofer/filesystem.go index 94d96261b..bbb01148b 100644 --- a/pkg/sentry/fsimpl/gofer/filesystem.go +++ b/pkg/sentry/fsimpl/gofer/filesystem.go @@ -30,12 +30,11 @@ import ( "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) // Sync implements vfs.FilesystemImpl.Sync. func (fs *filesystem) Sync(ctx context.Context) error { - // Snapshot current syncable dentries and special files. + // Snapshot current syncable dentries and special file FDs. fs.syncMu.Lock() ds := make([]*dentry, 0, len(fs.syncableDentries)) for d := range fs.syncableDentries { @@ -53,22 +52,28 @@ func (fs *filesystem) Sync(ctx context.Context) error { // regardless. var retErr error - // Sync regular files. + // Sync syncable dentries. for _, d := range ds { - err := d.syncCachedFile(ctx) + err := d.syncCachedFile(ctx, true /* forFilesystemSync */) d.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: dentry.syncCachedFile failed: %v", err) + if retErr == nil { + retErr = err + } } } // Sync special files, which may be writable but do not use dentry shared // handles (so they won't be synced by the above). for _, sffd := range sffds { - err := sffd.Sync(ctx) + err := sffd.sync(ctx, true /* forFilesystemSync */) sffd.vfsfd.DecRef(ctx) - if err != nil && retErr == nil { - retErr = err + if err != nil { + ctx.Infof("gofer.filesystem.Sync: specialFileFD.sync failed: %v", err) + if retErr == nil { + retErr = err + } } } @@ -229,7 +234,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir return nil, err } if child != nil { - if !file.isNil() && inoFromPath(qid.Path) == child.ino { + if !file.isNil() && qid.Path == child.qidPath { // The file at this path hasn't changed. Just update cached metadata. file.close(ctx) child.updateFromP9AttrsLocked(attrMask, &attr) @@ -256,7 +261,7 @@ func (fs *filesystem) revalidateChildLocked(ctx context.Context, vfsObj *vfs.Vir // treat their invalidation as deletion. child.setDeleted() parent.syntheticChildren-- - child.decRefLocked() + child.decRefNoCaching() parent.dirents = nil } *ds = appendDentry(*ds, child) @@ -366,9 +371,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if len(name) > maxFilenameLen { return syserror.ENAMETOOLONG } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.isDeleted() { return syserror.ENOENT } @@ -383,6 +385,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } if createInSyntheticDir == nil { return syserror.EPERM } @@ -402,6 +407,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil && child.isSynthetic() { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // The existence of a non-synthetic dentry at name would be inconclusive // because the file it represents may have been deleted from the remote // filesystem, so we would need to make an RPC to revalidate the dentry. @@ -422,6 +430,9 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if child := parent.children[name]; child != nil { return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } // No cached dentry exists; however, there might still be an existing file // at name. As above, we attempt the file creation RPC anyway. if err := createInRemoteDir(parent, name, &ds); err != nil { @@ -625,7 +636,7 @@ func (fs *filesystem) unlinkAt(ctx context.Context, rp *vfs.ResolvingPath, dir b child.setDeleted() if child.isSynthetic() { parent.syntheticChildren-- - child.decRefLocked() + child.decRefNoCaching() } ds = appendDentry(ds, child) } @@ -836,7 +847,7 @@ func (fs *filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v mode: opts.Mode, kuid: creds.EffectiveKUID, kgid: creds.EffectiveKGID, - pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize), }) return nil } @@ -1355,7 +1366,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa replaced.setDeleted() if replaced.isSynthetic() { newParent.syntheticChildren-- - replaced.decRefLocked() + replaced.decRefNoCaching() } ds = appendDentry(ds, replaced) } @@ -1364,7 +1375,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // with reference counts and queue oldParent for checkCachingLocked if the // parent isn't actually changing. if oldParent != newParent { - oldParent.decRefLocked() + oldParent.decRefNoCaching() ds = appendDentry(ds, oldParent) newParent.IncRef() if renamed.isSynthetic() { @@ -1512,7 +1523,6 @@ func (fs *filesystem) BoundEndpointAt(ctx context.Context, rp *vfs.ResolvingPath d.IncRef() return &endpoint{ dentry: d, - file: d.file.file, path: opts.Addr, }, nil } @@ -1591,7 +1601,3 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe defer fs.renameMu.RUnlock() return genericPrependPath(vfsroot, vd.Mount(), vd.Dentry().Impl().(*dentry), b) } - -func (fs *filesystem) nextSyntheticIno() inodeNumber { - return inodeNumber(atomic.AddUint64(&fs.syntheticSeq, 1) | syntheticInoMask) -} diff --git a/pkg/sentry/fsimpl/gofer/gofer.go b/pkg/sentry/fsimpl/gofer/gofer.go index f1dad1b08..6f82ce61b 100644 --- a/pkg/sentry/fsimpl/gofer/gofer.go +++ b/pkg/sentry/fsimpl/gofer/gofer.go @@ -26,6 +26,9 @@ // *** "memmap.Mappable locks taken by Translate" below this point // dentry.handleMu // dentry.dataMu +// filesystem.inoMu +// specialFileFD.mu +// specialFileFD.bufMu // // Locking dentry.dirMu in multiple dentries requires that either ancestor // dentries are locked before descendant dentries, or that filesystem.renameMu @@ -36,7 +39,6 @@ import ( "fmt" "strconv" "strings" - "sync" "sync/atomic" "syscall" @@ -44,6 +46,8 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/p9" + refs_vfs1 "gvisor.dev/gvisor/pkg/refs" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/fs/fsutil" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" @@ -53,6 +57,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/unet" "gvisor.dev/gvisor/pkg/usermem" @@ -81,7 +86,7 @@ type filesystem struct { iopts InternalFilesystemOptions // client is the client used by this filesystem. client is immutable. - client *p9.Client `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + client *p9.Client `state:"nosave"` // clock is a realtime clock used to set timestamps in file operations. clock ktime.Clock @@ -89,6 +94,9 @@ type filesystem struct { // devMinor is the filesystem's minor device number. devMinor is immutable. devMinor uint32 + // root is the root dentry. root is immutable. + root *dentry + // renameMu serves two purposes: // // - It synchronizes path resolution with renaming initiated by this @@ -103,39 +111,35 @@ type filesystem struct { // cachedDentries contains all dentries with 0 references. (Due to race // conditions, it may also contain dentries with non-zero references.) - // cachedDentriesLen is the number of dentries in cachedDentries. These - // fields are protected by renameMu. + // cachedDentriesLen is the number of dentries in cachedDentries. These fields + // are protected by renameMu. cachedDentries dentryList cachedDentriesLen uint64 - // syncableDentries contains all dentries in this filesystem for which - // !dentry.file.isNil(). specialFileFDs contains all open specialFileFDs. - // These fields are protected by syncMu. + // syncableDentries contains all non-synthetic dentries. specialFileFDs + // contains all open specialFileFDs. These fields are protected by syncMu. syncMu sync.Mutex `state:"nosave"` syncableDentries map[*dentry]struct{} specialFileFDs map[*specialFileFD]struct{} - // syntheticSeq stores a counter to used to generate unique inodeNumber for - // synthetic dentries. - syntheticSeq uint64 -} + // inoByQIDPath maps previously-observed QID.Paths to inode numbers + // assigned to those paths. inoByQIDPath is not preserved across + // checkpoint/restore because QIDs may be reused between different gofer + // processes, so QIDs may be repeated for different files across + // checkpoint/restore. inoByQIDPath is protected by inoMu. + inoMu sync.Mutex `state:"nosave"` + inoByQIDPath map[uint64]uint64 `state:"nosave"` -// inodeNumber represents inode number reported in Dirent.Ino. For regular -// dentries, it comes from QID.Path from the 9P server. Synthetic dentries -// have have their inodeNumber generated sequentially, with the MSB reserved to -// prevent conflicts with regular dentries. -// -// +stateify savable -type inodeNumber uint64 + // lastIno is the last inode number assigned to a file. lastIno is accessed + // using atomic memory operations. + lastIno uint64 -// Reserve MSB for synthetic mounts. -const syntheticInoMask = uint64(1) << 63 + // savedDentryRW records open read/write handles during save/restore. + savedDentryRW map[*dentry]savedDentryRW -func inoFromPath(path uint64) inodeNumber { - if path&syntheticInoMask != 0 { - log.Warningf("Dropping MSB from ino, collision is possible. Original: %d, new: %d", path, path&^syntheticInoMask) - } - return inodeNumber(path &^ syntheticInoMask) + // released is nonzero once filesystem.Release has been called. It is accessed + // with atomic memory operations. + released int32 } // +stateify savable @@ -149,8 +153,7 @@ type filesystemOptions struct { msize uint32 version string - // maxCachedDentries is the maximum number of dentries with 0 references - // retained by the client. + // maxCachedDentries is the maximum size of filesystem.cachedDentries. maxCachedDentries uint64 // If forcePageCache is true, host FDs may not be used for application @@ -247,6 +250,10 @@ const ( // // +stateify savable type InternalFilesystemOptions struct { + // If UniqueID is non-empty, it is an opaque string used to reassociate the + // filesystem with a new server FD during restoration from checkpoint. + UniqueID string + // If LeakConnection is true, do not close the connection to the server // when the Filesystem is released. This is necessary for deployments in // which servers can handle only a single client and report failure if that @@ -286,46 +293,11 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt mopts := vfs.GenericParseMountOptions(opts.Data) var fsopts filesystemOptions - // Check that the transport is "fd". - trans, ok := mopts["trans"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: transport must be specified as 'trans=fd'") - return nil, nil, syserror.EINVAL - } - delete(mopts, "trans") - if trans != "fd" { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: unsupported transport: trans=%s", trans) - return nil, nil, syserror.EINVAL - } - - // Check that read and write FDs are provided and identical. - rfdstr, ok := mopts["rfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD must be specified as 'rfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "rfdno") - rfd, err := strconv.Atoi(rfdstr) + fd, err := getFDFromMountOptionsMap(ctx, mopts) if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid read FD: rfdno=%s", rfdstr) - return nil, nil, syserror.EINVAL - } - wfdstr, ok := mopts["wfdno"] - if !ok { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: write FD must be specified as 'wfdno=<file descriptor>") - return nil, nil, syserror.EINVAL - } - delete(mopts, "wfdno") - wfd, err := strconv.Atoi(wfdstr) - if err != nil { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: invalid write FD: wfdno=%s", wfdstr) - return nil, nil, syserror.EINVAL - } - if rfd != wfd { - ctx.Warningf("gofer.FilesystemType.GetFilesystem: read FD (%d) and write FD (%d) must be equal", rfd, wfd) - return nil, nil, syserror.EINVAL + return nil, nil, err } - fsopts.fd = rfd + fsopts.fd = fd // Get the attach name. fsopts.aname = "/" @@ -441,57 +413,44 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } // If !ok, iopts being the zero value is correct. - // Establish a connection with the server. - conn, err := unet.NewSocket(fsopts.fd) + // Construct the filesystem object. + devMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } + fs := &filesystem{ + mfp: mfp, + opts: fsopts, + iopts: iopts, + clock: ktime.RealtimeClockFromContext(ctx), + devMinor: devMinor, + syncableDentries: make(map[*dentry]struct{}), + specialFileFDs: make(map[*specialFileFD]struct{}), + inoByQIDPath: make(map[uint64]uint64), + } + fs.vfsfs.Init(vfsObj, &fstype, fs) - // Perform version negotiation with the server. - ctx.UninterruptibleSleepStart(false) - client, err := p9.NewClient(conn, fsopts.msize, fsopts.version) - ctx.UninterruptibleSleepFinish(false) - if err != nil { - conn.Close() + // Connect to the server. + if err := fs.dial(ctx); err != nil { return nil, nil, err } - // Ownership of conn has been transferred to client. // Perform attach to obtain the filesystem root. ctx.UninterruptibleSleepStart(false) - attached, err := client.Attach(fsopts.aname) + attached, err := fs.client.Attach(fsopts.aname) ctx.UninterruptibleSleepFinish(false) if err != nil { - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } attachFile := p9file{attached} qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) if err != nil { attachFile.close(ctx) - client.Close() + fs.vfsfs.DecRef(ctx) return nil, nil, err } - // Construct the filesystem object. - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - attachFile.close(ctx) - client.Close() - return nil, nil, err - } - fs := &filesystem{ - mfp: mfp, - opts: fsopts, - iopts: iopts, - client: client, - clock: ktime.RealtimeClockFromContext(ctx), - devMinor: devMinor, - syncableDentries: make(map[*dentry]struct{}), - specialFileFDs: make(map[*specialFileFD]struct{}), - } - fs.vfsfs.Init(vfsObj, &fstype, fs) - // Construct the root dentry. root, err := fs.newDentry(ctx, attachFile, qid, attrMask, &attr) if err != nil { @@ -500,25 +459,87 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } // Set the root's reference count to 2. One reference is returned to the - // caller, and the other is deliberately leaked to prevent the root from - // being "cached" and subsequently evicted. Its resources will still be - // cleaned up by fs.Release(). + // caller, and the other is held by fs to prevent the root from being "cached" + // and subsequently evicted. root.refs = 2 + fs.root = root return &fs.vfsfs, &root.vfsd, nil } +func getFDFromMountOptionsMap(ctx context.Context, mopts map[string]string) (int, error) { + // Check that the transport is "fd". + trans, ok := mopts["trans"] + if !ok || trans != "fd" { + ctx.Warningf("gofer.getFDFromMountOptionsMap: transport must be specified as 'trans=fd'") + return -1, syserror.EINVAL + } + delete(mopts, "trans") + + // Check that read and write FDs are provided and identical. + rfdstr, ok := mopts["rfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD must be specified as 'rfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "rfdno") + rfd, err := strconv.Atoi(rfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid read FD: rfdno=%s", rfdstr) + return -1, syserror.EINVAL + } + wfdstr, ok := mopts["wfdno"] + if !ok { + ctx.Warningf("gofer.getFDFromMountOptionsMap: write FD must be specified as 'wfdno=<file descriptor>'") + return -1, syserror.EINVAL + } + delete(mopts, "wfdno") + wfd, err := strconv.Atoi(wfdstr) + if err != nil { + ctx.Warningf("gofer.getFDFromMountOptionsMap: invalid write FD: wfdno=%s", wfdstr) + return -1, syserror.EINVAL + } + if rfd != wfd { + ctx.Warningf("gofer.getFDFromMountOptionsMap: read FD (%d) and write FD (%d) must be equal", rfd, wfd) + return -1, syserror.EINVAL + } + return rfd, nil +} + +// Preconditions: fs.client == nil. +func (fs *filesystem) dial(ctx context.Context) error { + // Establish a connection with the server. + conn, err := unet.NewSocket(fs.opts.fd) + if err != nil { + return err + } + + // Perform version negotiation with the server. + ctx.UninterruptibleSleepStart(false) + client, err := p9.NewClient(conn, fs.opts.msize, fs.opts.version) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + conn.Close() + return err + } + // Ownership of conn has been transferred to client. + + fs.client = client + return nil +} + // Release implements vfs.FilesystemImpl.Release. func (fs *filesystem) Release(ctx context.Context) { - mf := fs.mfp.MemoryFile() + atomic.StoreInt32(&fs.released, 1) + mf := fs.mfp.MemoryFile() fs.syncMu.Lock() for d := range fs.syncableDentries { d.handleMu.Lock() d.dataMu.Lock() if h := d.writeHandleLocked(); h.isOpen() { // Write dirty cached data to the remote file. - if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, fs.mfp.MemoryFile(), h.writeFromBlocksAt); err != nil { + if err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, mf, h.writeFromBlocksAt); err != nil { log.Warningf("gofer.filesystem.Release: failed to flush dentry: %v", err) } // TODO(jamieliu): Do we need to flushf/fsync d? @@ -539,6 +560,21 @@ func (fs *filesystem) Release(ctx context.Context) { // fs. fs.syncMu.Unlock() + // If leak checking is enabled, release all outstanding references in the + // filesystem. We deliberately avoid doing this outside of leak checking; we + // have released all external resources above rather than relying on dentry + // destructors. + if refs_vfs1.GetLeakMode() != refs_vfs1.NoLeakChecking { + fs.renameMu.Lock() + fs.root.releaseSyntheticRecursiveLocked(ctx) + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // An extra reference was held by the filesystem on the root to prevent it from + // being cached/evicted. + fs.root.DecRef(ctx) + } + if !fs.iopts.LeakConnection { // Close the connection to the server. This implicitly clunks all fids. fs.client.Close() @@ -547,6 +583,31 @@ func (fs *filesystem) Release(ctx context.Context) { fs.vfsfs.VirtualFilesystem().PutAnonBlockDevMinor(fs.devMinor) } +// releaseSyntheticRecursiveLocked traverses the tree with root d and decrements +// the reference count on every synthetic dentry. Synthetic dentries have one +// reference for existence that should be dropped during filesystem.Release. +// +// Precondition: d.fs.renameMu is locked. +func (d *dentry) releaseSyntheticRecursiveLocked(ctx context.Context) { + if d.isSynthetic() { + d.decRefNoCaching() + d.checkCachingLocked(ctx) + } + if d.isDir() { + var children []*dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + if child != nil { + child.releaseSyntheticRecursiveLocked(ctx) + } + } + } +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -574,12 +635,15 @@ type dentry struct { // filesystem.renameMu. name string + // qidPath is the p9.QID.Path for this file. qidPath is immutable. + qidPath uint64 + // file is the unopened p9.File that backs this dentry. file is immutable. // // If file.isNil(), this dentry represents a synthetic file, i.e. a file // that does not exist on the remote filesystem. As of this writing, the // only files that can be synthetic are sockets, pipes, and directories. - file p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + file p9file `state:"nosave"` // If deleted is non-zero, the file represented by this dentry has been // deleted. deleted is accessed using atomic memory operations. @@ -623,12 +687,12 @@ type dentry struct { // To mutate: // - Lock metadataMu and use atomic operations to update because we might // have atomic readers that don't hold the lock. - metadataMu sync.Mutex `state:"nosave"` - ino inodeNumber // immutable - mode uint32 // type is immutable, perms are mutable - uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic - gid uint32 // auth.KGID, but ... - blockSize uint32 // 0 if unknown + metadataMu sync.Mutex `state:"nosave"` + ino uint64 // immutable + mode uint32 // type is immutable, perms are mutable + uid uint32 // auth.KUID, but stored as raw uint32 for sync/atomic + gid uint32 // auth.KGID, but ... + blockSize uint32 // 0 if unknown // Timestamps, all nsecs from the Unix epoch. atime int64 mtime int64 @@ -679,9 +743,9 @@ type dentry struct { // (isNil() == false), it may be mutated with handleMu locked, but cannot // be closed until the dentry is destroyed. handleMu sync.RWMutex `state:"nosave"` - readFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - writeFile p9file `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - hostFD int32 + readFile p9file `state:"nosave"` + writeFile p9file `state:"nosave"` + hostFD int32 `state:"nosave"` dataMu sync.RWMutex `state:"nosave"` @@ -758,8 +822,9 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d := &dentry{ fs: fs, + qidPath: qid.Path, file: file, - ino: inoFromPath(qid.Path), + ino: fs.inoFromQIDPath(qid.Path), mode: uint32(attr.Mode), uid: uint32(fs.opts.dfltuid), gid: uint32(fs.opts.dfltgid), @@ -795,13 +860,28 @@ func (fs *filesystem) newDentry(ctx context.Context, file p9file, qid p9.QID, ma d.nlink = uint32(attr.NLink) } d.vfsd.Init(d) - + refsvfs2.Register(d) fs.syncMu.Lock() fs.syncableDentries[d] = struct{}{} fs.syncMu.Unlock() return d, nil } +func (fs *filesystem) inoFromQIDPath(qidPath uint64) uint64 { + fs.inoMu.Lock() + defer fs.inoMu.Unlock() + if ino, ok := fs.inoByQIDPath[qidPath]; ok { + return ino + } + ino := fs.nextIno() + fs.inoByQIDPath[qidPath] = ino + return ino +} + +func (fs *filesystem) nextIno() uint64 { + return atomic.AddUint64(&fs.lastIno, 1) +} + func (d *dentry) isSynthetic() bool { return d.file.isNil() } @@ -853,7 +933,7 @@ func (d *dentry) updateFromP9AttrsLocked(mask p9.AttrMask, attr *p9.Attr) { } } -// Preconditions: !d.isSynthetic() +// Preconditions: !d.isSynthetic(). func (d *dentry) updateFromGetattr(ctx context.Context) error { // Use d.readFile or d.writeFile, which represent 9P fids that have been // opened, in preference to d.file, which represents a 9P fid that has not. @@ -916,10 +996,10 @@ func (d *dentry) statTo(stat *linux.Statx) { // This is consistent with regularFileFD.Seek(), which treats regular files // as having no holes. stat.Blocks = (stat.Size + 511) / 512 - stat.Atime = statxTimestampFromDentry(atomic.LoadInt64(&d.atime)) - stat.Btime = statxTimestampFromDentry(atomic.LoadInt64(&d.btime)) - stat.Ctime = statxTimestampFromDentry(atomic.LoadInt64(&d.ctime)) - stat.Mtime = statxTimestampFromDentry(atomic.LoadInt64(&d.mtime)) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.atime)) + stat.Btime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.btime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.ctime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&d.mtime)) stat.DevMajor = linux.UNNAMED_MAJOR stat.DevMinor = d.fs.devMinor } @@ -967,10 +1047,10 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // Use client clocks for timestamps. now = d.fs.clock.Now().Nanoseconds() if stat.Mask&linux.STATX_ATIME != 0 && stat.Atime.Nsec == linux.UTIME_NOW { - stat.Atime = statxTimestampFromDentry(now) + stat.Atime = linux.NsecToStatxTimestamp(now) } if stat.Mask&linux.STATX_MTIME != 0 && stat.Mtime.Nsec == linux.UTIME_NOW { - stat.Mtime = statxTimestampFromDentry(now) + stat.Mtime = linux.NsecToStatxTimestamp(now) } } @@ -1029,11 +1109,11 @@ func (d *dentry) setStat(ctx context.Context, creds *auth.Credentials, opts *vfs // !d.cachedMetadataAuthoritative() then we returned after calling // d.file.setAttr(). For the same reason, now must have been initialized. if stat.Mask&linux.STATX_ATIME != 0 { - atomic.StoreInt64(&d.atime, dentryTimestampFromStatx(stat.Atime)) + atomic.StoreInt64(&d.atime, stat.Atime.ToNsec()) atomic.StoreUint32(&d.atimeDirty, 0) } if stat.Mask&linux.STATX_MTIME != 0 { - atomic.StoreInt64(&d.mtime, dentryTimestampFromStatx(stat.Mtime)) + atomic.StoreInt64(&d.mtime, stat.Mtime.ToNsec()) atomic.StoreUint32(&d.mtimeDirty, 0) } atomic.StoreInt64(&d.ctime, now) @@ -1139,17 +1219,19 @@ func dentryGIDFromP9GID(gid p9.GID) uint32 { func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkCachingLocked(). - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) return true } } @@ -1157,22 +1239,41 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + if d.decRefNoCaching() == 0 { d.fs.renameMu.Lock() d.checkCachingLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") } } -// decRefLocked decrements d's reference count without calling +// decRefNoCaching decrements d's reference count without calling // d.checkCachingLocked, even if d's reference count reaches 0; callers are // responsible for ensuring that d.checkCachingLocked will be called later. -func (d *dentry) decRefLocked() { - if refs := atomic.AddInt64(&d.refs, -1); refs < 0 { - panic("gofer.dentry.decRefLocked() called without holding a reference") +func (d *dentry) decRefNoCaching() int64 { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r < 0 { + panic("gofer.dentry.decRefNoCaching() called without holding a reference") } + return r +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "gofer.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[gofer.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -1223,6 +1324,10 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { // resolution, which requires renameMu, so if d.refs is zero then it will // remain zero while we hold renameMu for writing.) refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + return + } if refs > 0 { if d.cached { d.fs.cachedDentries.Remove(d) @@ -1231,10 +1336,6 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { } return } - if refs == -1 { - // Dentry has already been destroyed. - return - } // Deleted and invalidated dentries with zero references are no longer // reachable by path resolution and should be dropped immediately. if d.vfsd.IsDead() { @@ -1257,6 +1358,16 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { if d.watches.Size() > 0 { return } + + if atomic.LoadInt32(&d.fs.released) != 0 { + if d.parent != nil { + d.parent.dirMu.Lock() + delete(d.parent.children, d.name) + d.parent.dirMu.Unlock() + } + d.destroyLocked(ctx) + } + // If d is already cached, just move it to the front of the LRU. if d.cached { d.fs.cachedDentries.Remove(d) @@ -1269,33 +1380,48 @@ func (d *dentry) checkCachingLocked(ctx context.Context) { d.fs.cachedDentriesLen++ d.cached = true if d.fs.cachedDentriesLen > d.fs.opts.maxCachedDentries { - victim := d.fs.cachedDentries.Back() - d.fs.cachedDentries.Remove(victim) - d.fs.cachedDentriesLen-- - victim.cached = false - // victim.refs may have become non-zero from an earlier path resolution - // since it was inserted into fs.cachedDentries. - if atomic.LoadInt64(&victim.refs) == 0 { - if victim.parent != nil { - victim.parent.dirMu.Lock() - if !victim.vfsd.IsDead() { - // Note that victim can't be a mount point (in any mount - // namespace), since VFS holds references on mount points. - d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) - delete(victim.parent.children, victim.name) - // We're only deleting the dentry, not the file it - // represents, so we don't need to update - // victimParent.dirents etc. - } - victim.parent.dirMu.Unlock() - } - victim.destroyLocked(ctx) - } + d.fs.evictCachedDentryLocked(ctx) // Whether or not victim was destroyed, we brought fs.cachedDentriesLen // back down to fs.opts.maxCachedDentries, so we don't loop. } } +// Precondition: fs.renameMu must be locked for writing; it may be temporarily +// unlocked. +func (fs *filesystem) evictAllCachedDentriesLocked(ctx context.Context) { + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } +} + +// Preconditions: +// * fs.renameMu must be locked for writing; it may be temporarily unlocked. +// * fs.cachedDentriesLen != 0. +func (fs *filesystem) evictCachedDentryLocked(ctx context.Context) { + victim := fs.cachedDentries.Back() + fs.cachedDentries.Remove(victim) + fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // since it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if victim.parent != nil { + victim.parent.dirMu.Lock() + if !victim.vfsd.IsDead() { + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, &victim.vfsd) + delete(victim.parent.children, victim.name) + // We're only deleting the dentry, not the file it + // represents, so we don't need to update + // victimParent.dirents etc. + } + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } +} + // destroyLocked destroys the dentry. // // Preconditions: @@ -1373,13 +1499,10 @@ func (d *dentry) destroyLocked(ctx context.Context) { // Drop the reference held by d on its parent without recursively locking // d.fs.renameMu. - if d.parent != nil { - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkCachingLocked(ctx) - } else if refs < 0 { - panic("gofer.dentry.DecRef() called without holding a reference") - } + if d.parent != nil && d.parent.decRefNoCaching() == 0 { + d.parent.checkCachingLocked(ctx) } + refsvfs2.Unregister(d) } func (d *dentry) isDeleted() bool { @@ -1623,6 +1746,33 @@ func (d *dentry) syncRemoteFileLocked(ctx context.Context) error { return nil } +func (d *dentry) syncCachedFile(ctx context.Context, forFilesystemSync bool) error { + d.handleMu.RLock() + defer d.handleMu.RUnlock() + h := d.writeHandleLocked() + if h.isOpen() { + // Write back dirty pages to the remote file. + d.dataMu.Lock() + err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) + d.dataMu.Unlock() + if err != nil { + return err + } + } + if err := d.syncRemoteFileLocked(ctx); err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (d is a regular file and was opened for writing). + if d.isRegularFile() && h.isOpen() { + return err + } + ctx.Debugf("gofer.dentry.syncCachedFile: syncing non-writable or non-regular-file dentry failed: %v", err) + } + return nil +} + // incLinks increments link count. func (d *dentry) incLinks() { if atomic.LoadUint32(&d.nlink) == 0 { @@ -1650,7 +1800,7 @@ type fileDescription struct { vfs.FileDescriptionDefaultImpl vfs.LockFD - lockLogging sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + lockLogging sync.Once `state:"nosave"` } func (fd *fileDescription) filesystem() *filesystem { diff --git a/pkg/sentry/fsimpl/gofer/gofer_test.go b/pkg/sentry/fsimpl/gofer/gofer_test.go index bfe75dfe4..76f08e252 100644 --- a/pkg/sentry/fsimpl/gofer/gofer_test.go +++ b/pkg/sentry/fsimpl/gofer/gofer_test.go @@ -26,12 +26,13 @@ import ( func TestDestroyIdempotent(t *testing.T) { ctx := contexttest.Context(t) fs := filesystem{ - mfp: pgalloc.MemoryFileProviderFromContext(ctx), - syncableDentries: make(map[*dentry]struct{}), + mfp: pgalloc.MemoryFileProviderFromContext(ctx), opts: filesystemOptions{ // Test relies on no dentry being held in the cache. maxCachedDentries: 0, }, + syncableDentries: make(map[*dentry]struct{}), + inoByQIDPath: make(map[uint64]uint64), } attr := &p9.Attr{ diff --git a/pkg/sentry/fsimpl/gofer/host_named_pipe.go b/pkg/sentry/fsimpl/gofer/host_named_pipe.go index 7294de7d6..c7bf10007 100644 --- a/pkg/sentry/fsimpl/gofer/host_named_pipe.go +++ b/pkg/sentry/fsimpl/gofer/host_named_pipe.go @@ -51,8 +51,24 @@ func blockUntilNonblockingPipeHasWriter(ctx context.Context, fd int32) error { if ok { return nil } - if err := sleepBetweenNamedPipeOpenChecks(ctx); err != nil { - return err + if sleepErr := sleepBetweenNamedPipeOpenChecks(ctx); sleepErr != nil { + // Another application thread may have opened this pipe for + // writing, succeeded because we previously opened the pipe for + // reading, and subsequently interrupted us for checkpointing (e.g. + // this occurs in mknod tests under cooperative save/restore). In + // this case, our open has to succeed for the checkpoint to include + // a readable FD for the pipe, which is in turn necessary to + // restore the other thread's writable FD for the same pipe + // (otherwise it will get ENXIO). So we have to check + // nonblockingPipeHasWriter() once last time. + ok, err := nonblockingPipeHasWriter(fd) + if err != nil { + return err + } + if ok { + return nil + } + return sleepErr } } } diff --git a/pkg/sentry/fsimpl/gofer/regular_file.go b/pkg/sentry/fsimpl/gofer/regular_file.go index f8b19bae7..dc8a890cb 100644 --- a/pkg/sentry/fsimpl/gofer/regular_file.go +++ b/pkg/sentry/fsimpl/gofer/regular_file.go @@ -18,7 +18,6 @@ import ( "fmt" "io" "math" - "sync" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" @@ -31,6 +30,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/pgalloc" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -624,23 +624,7 @@ func regularFileSeekLocked(ctx context.Context, d *dentry, fdOffset, offset int6 // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *regularFileFD) Sync(ctx context.Context) error { - return fd.dentry().syncCachedFile(ctx) -} - -func (d *dentry) syncCachedFile(ctx context.Context) error { - d.handleMu.RLock() - defer d.handleMu.RUnlock() - - if h := d.writeHandleLocked(); h.isOpen() { - d.dataMu.Lock() - // Write dirty cached data to the remote file. - err := fsutil.SyncDirtyAll(ctx, &d.cache, &d.dirty, d.size, d.fs.mfp.MemoryFile(), h.writeFromBlocksAt) - d.dataMu.Unlock() - if err != nil { - return err - } - } - return d.syncRemoteFileLocked(ctx) + return fd.dentry().syncCachedFile(ctx, false /* lowSyncExpectations */) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. @@ -913,7 +897,7 @@ type dentryPlatformFile struct { hostFileMapper fsutil.HostFileMapper // hostFileMapperInitOnce is used to lazily initialize hostFileMapper. - hostFileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + hostFileMapperInitOnce sync.Once `state:"nosave"` } // IncRef implements memmap.File.IncRef. diff --git a/pkg/sentry/fsimpl/gofer/save_restore.go b/pkg/sentry/fsimpl/gofer/save_restore.go new file mode 100644 index 000000000..17849dcc0 --- /dev/null +++ b/pkg/sentry/fsimpl/gofer/save_restore.go @@ -0,0 +1,329 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gofer + +import ( + "fmt" + "io" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/abi/linux" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/p9" + "gvisor.dev/gvisor/pkg/refsvfs2" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" +) + +type saveRestoreContextID int + +const ( + // CtxRestoreServerFDMap is a Context.Value key for a map[string]int + // mapping filesystem unique IDs (cf. InternalFilesystemOptions.UniqueID) + // to host FDs. + CtxRestoreServerFDMap saveRestoreContextID = iota +) + +// +stateify savable +type savedDentryRW struct { + read bool + write bool +} + +// PreprareSave implements vfs.FilesystemImplSaveRestoreExtension.PrepareSave. +func (fs *filesystem) PrepareSave(ctx context.Context) error { + if len(fs.iopts.UniqueID) == 0 { + return fmt.Errorf("gofer.filesystem with no UniqueID cannot be saved") + } + + // Purge cached dentries, which may not be reopenable after restore due to + // permission changes. + fs.renameMu.Lock() + fs.evictAllCachedDentriesLocked(ctx) + fs.renameMu.Unlock() + + // Buffer pipe data so that it's available for reading after restore. (This + // is a legacy VFS1 feature.) + fs.syncMu.Lock() + for sffd := range fs.specialFileFDs { + if sffd.dentry().fileType() == linux.S_IFIFO && sffd.vfsfd.IsReadable() { + if err := sffd.savePipeData(ctx); err != nil { + fs.syncMu.Unlock() + return err + } + } + } + fs.syncMu.Unlock() + + // Flush local state to the remote filesystem. + if err := fs.Sync(ctx); err != nil { + return err + } + + fs.savedDentryRW = make(map[*dentry]savedDentryRW) + return fs.root.prepareSaveRecursive(ctx) +} + +// Preconditions: +// * fd represents a pipe. +// * fd is readable. +func (fd *specialFileFD) savePipeData(ctx context.Context) error { + fd.bufMu.Lock() + defer fd.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), ^uint64(0)) + if n != 0 { + fd.buf = append(fd.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syserror.EAGAIN { + break + } + return err + } + } + if len(fd.buf) != 0 { + atomic.StoreUint32(&fd.haveBuf, 1) + } + return nil +} + +func (d *dentry) prepareSaveRecursive(ctx context.Context) error { + if d.isRegularFile() && !d.cachedMetadataAuthoritative() { + // Get updated metadata for d in case we need to perform metadata + // validation during restore. + if err := d.updateFromGetattr(ctx); err != nil { + return err + } + } + if !d.readFile.isNil() || !d.writeFile.isNil() { + d.fs.savedDentryRW[d] = savedDentryRW{ + read: !d.readFile.isNil(), + write: !d.writeFile.isNil(), + } + } + d.dirMu.Lock() + defer d.dirMu.Unlock() + for _, child := range d.children { + if child != nil { + if err := child.prepareSaveRecursive(ctx); err != nil { + return err + } + } + } + return nil +} + +// beforeSave is invoked by stateify. +func (d *dentry) beforeSave() { + if d.vfsd.IsDead() { + panic(fmt.Sprintf("gofer.dentry(%q).beforeSave: deleted and invalidated dentries can't be restored", genericDebugPathname(d))) + } +} + +// afterLoad is invoked by stateify. +func (d *dentry) afterLoad() { + d.hostFD = -1 + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} + +// afterLoad is invoked by stateify. +func (d *dentryPlatformFile) afterLoad() { + if d.hostFileMapper.IsInited() { + // Ensure that we don't call d.hostFileMapper.Init() again. + d.hostFileMapperInitOnce.Do(func() {}) + } +} + +// afterLoad is invoked by stateify. +func (fd *specialFileFD) afterLoad() { + fd.handle.fd = -1 +} + +// CompleteRestore implements +// vfs.FilesystemImplSaveRestoreExtension.CompleteRestore. +func (fs *filesystem) CompleteRestore(ctx context.Context, opts vfs.CompleteRestoreOptions) error { + fdmapv := ctx.Value(CtxRestoreServerFDMap) + if fdmapv == nil { + return fmt.Errorf("no server FD map available") + } + fdmap := fdmapv.(map[string]int) + fd, ok := fdmap[fs.iopts.UniqueID] + if !ok { + return fmt.Errorf("no server FD available for filesystem with unique ID %q", fs.iopts.UniqueID) + } + fs.opts.fd = fd + if err := fs.dial(ctx); err != nil { + return err + } + fs.inoByQIDPath = make(map[uint64]uint64) + + // Restore the filesystem root. + ctx.UninterruptibleSleepStart(false) + attached, err := fs.client.Attach(fs.opts.aname) + ctx.UninterruptibleSleepFinish(false) + if err != nil { + return err + } + attachFile := p9file{attached} + qid, attrMask, attr, err := attachFile.getAttr(ctx, dentryAttrMask()) + if err != nil { + return err + } + if err := fs.root.restoreFile(ctx, attachFile, qid, attrMask, &attr, &opts); err != nil { + return err + } + + // Restore remaining dentries. + if err := fs.root.restoreDescendantsRecursive(ctx, &opts); err != nil { + return err + } + + // Re-open handles for specialFileFDs. Unlike the initial open + // (dentry.openSpecialFile()), pipes are always opened without blocking; + // non-readable pipe FDs are opened last to ensure that they don't get + // ENXIO if another specialFileFD represents the read end of the same pipe. + // This is consistent with VFS1. + haveWriteOnlyPipes := false + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + haveWriteOnlyPipes = true + continue + } + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + if haveWriteOnlyPipes { + for fd := range fs.specialFileFDs { + if fd.dentry().fileType() == linux.S_IFIFO && !fd.vfsfd.IsReadable() { + if err := fd.completeRestore(ctx); err != nil { + return err + } + } + } + } + + // Discard state only required during restore. + fs.savedDentryRW = nil + + return nil +} + +func (d *dentry) restoreFile(ctx context.Context, file p9file, qid p9.QID, attrMask p9.AttrMask, attr *p9.Attr, opts *vfs.CompleteRestoreOptions) error { + d.file = file + + // Gofers do not preserve QID across checkpoint/restore, so: + // + // - We must assume that the remote filesystem did not change in a way that + // would invalidate dentries, since we can't revalidate dentries by + // checking QIDs. + // + // - We need to associate the new QID.Path with the existing d.ino. + d.qidPath = qid.Path + d.fs.inoMu.Lock() + d.fs.inoByQIDPath[qid.Path] = d.ino + d.fs.inoMu.Unlock() + + // Check metadata stability before updating metadata. + d.metadataMu.Lock() + defer d.metadataMu.Unlock() + if d.isRegularFile() { + if opts.ValidateFileSizes { + if !attrMask.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: file size not available", genericDebugPathname(d)) + } + if d.size != attr.Size { + return fmt.Errorf("gofer.dentry(%q).restoreFile: file size validation failed: size changed from %d to %d", genericDebugPathname(d), d.size, attr.Size) + } + } + if opts.ValidateFileModificationTimestamps { + if !attrMask.MTime { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime not available", genericDebugPathname(d)) + } + if want := dentryTimestampFromP9(attr.MTimeSeconds, attr.MTimeNanoSeconds); d.mtime != want { + return fmt.Errorf("gofer.dentry(%q).restoreFile: mtime validation failed: mtime changed from %+v to %+v", genericDebugPathname(d), linux.NsecToStatxTimestamp(d.mtime), linux.NsecToStatxTimestamp(want)) + } + } + } + if !d.cachedMetadataAuthoritative() { + d.updateFromP9AttrsLocked(attrMask, attr) + } + + if rw, ok := d.fs.savedDentryRW[d]; ok { + if err := d.ensureSharedHandle(ctx, rw.read, rw.write, false /* trunc */); err != nil { + return err + } + } + + return nil +} + +// Preconditions: d is not synthetic. +func (d *dentry) restoreDescendantsRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + for _, child := range d.children { + if child == nil { + continue + } + if _, ok := d.fs.syncableDentries[child]; !ok { + // child is synthetic. + continue + } + if err := child.restoreRecursive(ctx, opts); err != nil { + return err + } + } + return nil +} + +// Preconditions: d is not synthetic (but note that since this function +// restores d.file, d.file.isNil() is always true at this point, so this can +// only be detected by checking filesystem.syncableDentries). d.parent has been +// restored. +func (d *dentry) restoreRecursive(ctx context.Context, opts *vfs.CompleteRestoreOptions) error { + qid, file, attrMask, attr, err := d.parent.file.walkGetAttrOne(ctx, d.name) + if err != nil { + return err + } + if err := d.restoreFile(ctx, file, qid, attrMask, &attr, opts); err != nil { + return err + } + return d.restoreDescendantsRecursive(ctx, opts) +} + +func (fd *specialFileFD) completeRestore(ctx context.Context) error { + d := fd.dentry() + h, err := openHandle(ctx, d.file, fd.vfsfd.IsReadable(), fd.vfsfd.IsWritable(), false /* trunc */) + if err != nil { + return err + } + fd.handle = h + + ftype := d.fileType() + fd.haveQueue = (ftype == linux.S_IFIFO || ftype == linux.S_IFSOCK) && fd.handle.fd >= 0 + if fd.haveQueue { + if err := fdnotifier.AddFD(fd.handle.fd, &fd.queue); err != nil { + return err + } + } + + return nil +} diff --git a/pkg/sentry/fsimpl/gofer/socket.go b/pkg/sentry/fsimpl/gofer/socket.go index 326b940a7..a21199eac 100644 --- a/pkg/sentry/fsimpl/gofer/socket.go +++ b/pkg/sentry/fsimpl/gofer/socket.go @@ -42,9 +42,6 @@ type endpoint struct { // dentry is the filesystem dentry which produced this endpoint. dentry *dentry - // file is the p9 file that contains a single unopened fid. - file p9.File `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. - // path is the sentry path where this endpoint is bound. path string } @@ -116,7 +113,7 @@ func (e *endpoint) UnidirectionalConnect(ctx context.Context) (transport.Connect } func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFlags, queue *waiter.Queue) (*host.SCMConnectedEndpoint, *syserr.Error) { - hostFile, err := e.file.Connect(flags) + hostFile, err := e.dentry.file.connect(ctx, flags) if err != nil { return nil, syserr.ErrConnectionRefused } @@ -131,7 +128,7 @@ func (e *endpoint) newConnectedEndpoint(ctx context.Context, flags p9.ConnectFla c, serr := host.NewSCMEndpoint(ctx, hostFD, queue, e.path) if serr != nil { - log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.file, flags, serr) + log.Warningf("Gofer returned invalid host socket for BidirectionalConnect; file %+v flags %+v: %v", e.dentry.file, flags, serr) return nil, serr } return c, nil diff --git a/pkg/sentry/fsimpl/gofer/special_file.go b/pkg/sentry/fsimpl/gofer/special_file.go index 71581736c..625400c0b 100644 --- a/pkg/sentry/fsimpl/gofer/special_file.go +++ b/pkg/sentry/fsimpl/gofer/special_file.go @@ -15,7 +15,6 @@ package gofer import ( - "sync" "sync/atomic" "syscall" @@ -25,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/p9" "gvisor.dev/gvisor/pkg/safemem" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" @@ -40,7 +40,7 @@ type specialFileFD struct { fileDescription // handle is used for file I/O. handle is immutable. - handle handle `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + handle handle `state:"nosave"` // isRegularFile is true if this FD represents a regular file which is only // possible when filesystemOptions.regularFilesUseSpecialFileFD is in @@ -54,12 +54,20 @@ type specialFileFD struct { // haveQueue is true if this file description represents a file for which // queue may send I/O readiness events. haveQueue is immutable. - haveQueue bool + haveQueue bool `state:"nosave"` queue waiter.Queue // If seekable is true, off is the file offset. off is protected by mu. mu sync.Mutex `state:"nosave"` off int64 + + // If haveBuf is non-zero, this FD represents a pipe, and buf contains data + // read from the pipe from previous calls to specialFileFD.savePipeData(). + // haveBuf and buf are protected by bufMu. haveBuf is accessed using atomic + // memory operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte } func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, flags uint32) (*specialFileFD, error) { @@ -87,6 +95,9 @@ func newSpecialFileFD(h handle, mnt *vfs.Mount, d *dentry, locks *vfs.FileLocks, } return nil, err } + d.fs.syncMu.Lock() + d.fs.specialFileFDs[fd] = struct{}{} + d.fs.syncMu.Unlock() return fd, nil } @@ -161,26 +172,51 @@ func (fd *specialFileFD) PRead(ctx context.Context, dst usermem.IOSequence, offs return 0, syserror.EOPNOTSUPP } - // Going through dst.CopyOutFrom() holds MM locks around file operations of - // unknown duration. For regularFileFD, doing so is necessary to support - // mmap due to lock ordering; MM locks precede dentry.dataMu. That doesn't - // hold here since specialFileFD doesn't client-cache data. Just buffer the - // read instead. if d := fd.dentry(); d.cachedMetadataAuthoritative() { d.touchAtime(fd.vfsfd.Mount()) } + + bufN := int64(0) + if atomic.LoadUint32(&fd.haveBuf) != 0 { + var err error + fd.bufMu.Lock() + if len(fd.buf) != 0 { + var n int + n, err = dst.CopyOut(ctx, fd.buf) + dst = dst.DropFirst(n) + fd.buf = fd.buf[n:] + if len(fd.buf) == 0 { + atomic.StoreUint32(&fd.haveBuf, 0) + fd.buf = nil + } + bufN = int64(n) + if offset >= 0 { + offset += bufN + } + } + fd.bufMu.Unlock() + if err != nil { + return bufN, err + } + } + + // Going through dst.CopyOutFrom() would hold MM locks around file + // operations of unknown duration. For regularFileFD, doing so is necessary + // to support mmap due to lock ordering; MM locks precede dentry.dataMu. + // That doesn't hold here since specialFileFD doesn't client-cache data. + // Just buffer the read instead. buf := make([]byte, dst.NumBytes()) n, err := fd.handle.readToBlocksAt(ctx, safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf)), uint64(offset)) if err == syserror.EAGAIN { err = syserror.ErrWouldBlock } if n == 0 { - return 0, err + return bufN, err } if cp, cperr := dst.CopyOut(ctx, buf[:n]); cperr != nil { - return int64(cp), cperr + return bufN + int64(cp), cperr } - return int64(n), err + return bufN + int64(n), err } // Read implements vfs.FileDescriptionImpl.Read. @@ -217,16 +253,16 @@ func (fd *specialFileFD) pwrite(ctx context.Context, src usermem.IOSequence, off } d := fd.dentry() - // If the regular file fd was opened with O_APPEND, make sure the file size - // is updated. There is a possible race here if size is modified externally - // after metadata cache is updated. - if fd.isRegularFile && fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { - if err := d.updateFromGetattr(ctx); err != nil { - return 0, offset, err + if fd.isRegularFile { + // If the regular file fd was opened with O_APPEND, make sure the file + // size is updated. There is a possible race here if size is modified + // externally after metadata cache is updated. + if fd.vfsfd.StatusFlags()&linux.O_APPEND != 0 && !d.cachedMetadataAuthoritative() { + if err := d.updateFromGetattr(ctx); err != nil { + return 0, offset, err + } } - } - if fd.isRegularFile { // We need to hold the metadataMu *while* writing to a regular file. d.metadataMu.Lock() defer d.metadataMu.Unlock() @@ -306,13 +342,31 @@ func (fd *specialFileFD) Seek(ctx context.Context, offset int64, whence int32) ( // Sync implements vfs.FileDescriptionImpl.Sync. func (fd *specialFileFD) Sync(ctx context.Context) error { - // If we have a host FD, fsyncing it is likely to be faster than an fsync - // RPC. - if fd.handle.fd >= 0 { - ctx.UninterruptibleSleepStart(false) - err := syscall.Fsync(int(fd.handle.fd)) - ctx.UninterruptibleSleepFinish(false) - return err + return fd.sync(ctx, false /* forFilesystemSync */) +} + +func (fd *specialFileFD) sync(ctx context.Context, forFilesystemSync bool) error { + err := func() error { + // If we have a host FD, fsyncing it is likely to be faster than an fsync + // RPC. + if fd.handle.fd >= 0 { + ctx.UninterruptibleSleepStart(false) + err := syscall.Fsync(int(fd.handle.fd)) + ctx.UninterruptibleSleepFinish(false) + return err + } + return fd.handle.file.fsync(ctx) + }() + if err != nil { + if !forFilesystemSync { + return err + } + // Only return err if we can reasonably have expected sync to succeed + // (fd represents a regular file that was opened for writing). + if fd.isRegularFile && fd.vfsfd.IsWritable() { + return err + } + ctx.Debugf("gofer.specialFileFD.sync: syncing non-writable or non-regular-file FD failed: %v", err) } - return fd.handle.file.fsync(ctx) + return nil } diff --git a/pkg/sentry/fsimpl/gofer/time.go b/pkg/sentry/fsimpl/gofer/time.go index 7e825caae..9cbe805b9 100644 --- a/pkg/sentry/fsimpl/gofer/time.go +++ b/pkg/sentry/fsimpl/gofer/time.go @@ -17,7 +17,6 @@ package gofer import ( "sync/atomic" - "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/vfs" ) @@ -25,17 +24,6 @@ func dentryTimestampFromP9(s, ns uint64) int64 { return int64(s*1e9 + ns) } -func dentryTimestampFromStatx(ts linux.StatxTimestamp) int64 { - return ts.Sec*1e9 + int64(ts.Nsec) -} - -func statxTimestampFromDentry(ns int64) linux.StatxTimestamp { - return linux.StatxTimestamp{ - Sec: ns / 1e9, - Nsec: uint32(ns % 1e9), - } -} - // Preconditions: d.cachedMetadataAuthoritative() == true. func (d *dentry) touchAtime(mnt *vfs.Mount) { if mnt.Flags.NoATime || mnt.ReadOnly() { diff --git a/pkg/sentry/fsimpl/host/BUILD b/pkg/sentry/fsimpl/host/BUILD index 56bcf9bdb..4ae9d6d5e 100644 --- a/pkg/sentry/fsimpl/host/BUILD +++ b/pkg/sentry/fsimpl/host/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "inode_refs.go", package = "host", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "connected_endpoint_refs.go", package = "host", prefix = "ConnectedEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ConnectedEndpoint", }, @@ -33,7 +33,7 @@ go_library( "host.go", "inode_refs.go", "ioctl_unsafe.go", - "mmap.go", + "save_restore.go", "socket.go", "socket_iovec.go", "socket_unsafe.go", @@ -51,6 +51,7 @@ go_library( "//pkg/log", "//pkg/marshal/primitive", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs/fsutil", diff --git a/pkg/sentry/fsimpl/host/control.go b/pkg/sentry/fsimpl/host/control.go index 0135e4428..13ef48cb5 100644 --- a/pkg/sentry/fsimpl/host/control.go +++ b/pkg/sentry/fsimpl/host/control.go @@ -79,7 +79,7 @@ func fdsToFiles(ctx context.Context, fds []int) []*vfs.FileDescription { } // Create the file backed by hostFD. - file, err := ImportFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, false /* isTTY */) + file, err := NewFD(ctx, kernel.KernelFromContext(ctx).HostMount(), fd, &NewFDOptions{}) if err != nil { ctx.Warningf("Error creating file from host FD: %v", err) break diff --git a/pkg/sentry/fsimpl/host/host.go b/pkg/sentry/fsimpl/host/host.go index 698e913fe..39b902a3e 100644 --- a/pkg/sentry/fsimpl/host/host.go +++ b/pkg/sentry/fsimpl/host/host.go @@ -19,6 +19,7 @@ package host import ( "fmt" "math" + "sync/atomic" "syscall" "golang.org/x/sys/unix" @@ -40,34 +41,97 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) (*inode, error) { - // Determine if hostFD is seekable. If not, this syscall will return ESPIPE - // (see fs/read_write.c:llseek), e.g. for pipes, sockets, and some character - // devices. +// inode implements kernfs.Inode. +// +// +stateify savable +type inode struct { + kernfs.InodeNoStatFS + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + kernfs.CachedMappable + kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. + + locks vfs.FileLocks + + // When the reference count reaches zero, the host fd is closed. + inodeRefs + + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + hostFD int + + // ino is an inode number unique within this filesystem. + // + // This field is initialized at creation time and is immutable. + ino uint64 + + // ftype is the file's type (a linux.S_IFMT mask). + // + // This field is initialized at creation time and is immutable. + ftype uint16 + + // mayBlock is true if hostFD is non-blocking, and operations on it may + // return EAGAIN or EWOULDBLOCK instead of blocking. + // + // This field is initialized at creation time and is immutable. + mayBlock bool + + // seekable is false if lseek(hostFD) returns ESPIPE. We assume that file + // offsets are meaningful iff seekable is true. + // + // This field is initialized at creation time and is immutable. + seekable bool + + // isTTY is true if this file represents a TTY. + // + // This field is initialized at creation time and is immutable. + isTTY bool + + // savable is true if hostFD may be saved/restored by its numeric value. + // + // This field is initialized at creation time and is immutable. + savable bool + + // Event queue for blocking operations. + queue waiter.Queue + + // If haveBuf is non-zero, hostFD represents a pipe, and buf contains data + // read from the pipe from previous calls to inode.beforeSave(). haveBuf + // and buf are protected by bufMu. haveBuf is accessed using atomic memory + // operations. + bufMu sync.Mutex `state:"nosave"` + haveBuf uint32 + buf []byte +} + +func newInode(ctx context.Context, fs *filesystem, hostFD int, savable bool, fileType linux.FileMode, isTTY bool) (*inode, error) { + // Determine if hostFD is seekable. _, err := unix.Seek(hostFD, 0, linux.SEEK_CUR) seekable := err != syserror.ESPIPE + // We expect regular files to be seekable, as this is required for them to + // be memory-mappable. + if !seekable && fileType == syscall.S_IFREG { + ctx.Infof("host.newInode: host FD %d is a non-seekable regular file", hostFD) + return nil, syserror.ESPIPE + } i := &inode{ - hostFD: hostFD, - ino: fs.NextIno(), - isTTY: isTTY, - wouldBlock: wouldBlock(uint32(fileType)), - seekable: seekable, - // NOTE(b/38213152): Technically, some obscure char devices can be memory - // mapped, but we only allow regular files. - canMap: fileType == linux.S_IFREG, - } - i.pf.inode = i + hostFD: hostFD, + ino: fs.NextIno(), + ftype: uint16(fileType), + mayBlock: fileType != syscall.S_IFREG && fileType != syscall.S_IFDIR, + seekable: seekable, + isTTY: isTTY, + savable: savable, + } + i.CachedMappable.Init(hostFD) i.EnableLeakCheck() - // Non-seekable files can't be memory mapped, assert this. - if !i.seekable && i.canMap { - panic("files that can return EWOULDBLOCK (sockets, pipes, etc.) cannot be memory mapped") - } - - // If the hostFD would block, we must set it to non-blocking and handle - // blocking behavior in the sentry. - if i.wouldBlock { + // If the hostFD can return EWOULDBLOCK when set to non-blocking, do so and + // handle blocking behavior in the sentry. + if i.mayBlock { if err := syscall.SetNonblock(i.hostFD, true); err != nil { return nil, err } @@ -80,6 +144,11 @@ func newInode(fs *filesystem, hostFD int, fileType linux.FileMode, isTTY bool) ( // NewFDOptions contains options to NewFD. type NewFDOptions struct { + // If Savable is true, the host file descriptor may be saved/restored by + // numeric value; the sandbox API requires a corresponding host FD with the + // same numeric value to be provieded at time of restore. + Savable bool + // If IsTTY is true, the file descriptor is a TTY. IsTTY bool @@ -114,7 +183,7 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) } d := &kernfs.Dentry{} - i, err := newInode(fs, hostFD, linux.FileMode(s.Mode).FileType(), opts.IsTTY) + i, err := newInode(ctx, fs, hostFD, opts.Savable, linux.FileMode(s.Mode).FileType(), opts.IsTTY) if err != nil { return nil, err } @@ -132,7 +201,8 @@ func NewFD(ctx context.Context, mnt *vfs.Mount, hostFD int, opts *NewFDOptions) // ImportFD sets up and returns a vfs.FileDescription from a donated fd. func ImportFD(ctx context.Context, mnt *vfs.Mount, hostFD int, isTTY bool) (*vfs.FileDescription, error) { return NewFD(ctx, mnt, hostFD, &NewFDOptions{ - IsTTY: isTTY, + Savable: true, + IsTTY: isTTY, }) } @@ -191,68 +261,6 @@ func (fs *filesystem) PrependPath(ctx context.Context, vfsroot, vd vfs.VirtualDe return vfs.PrependPathSyntheticError{} } -// inode implements kernfs.Inode. -// -// +stateify savable -type inode struct { - kernfs.InodeNoStatFS - kernfs.InodeNotDirectory - kernfs.InodeNotSymlink - kernfs.InodeTemporary // This holds no meaning as this inode can't be Looked up and is always valid. - - locks vfs.FileLocks - - // When the reference count reaches zero, the host fd is closed. - inodeRefs - - // hostFD contains the host fd that this file was originally created from, - // which must be available at time of restore. - // - // This field is initialized at creation time and is immutable. - hostFD int - - // ino is an inode number unique within this filesystem. - // - // This field is initialized at creation time and is immutable. - ino uint64 - - // isTTY is true if this file represents a TTY. - // - // This field is initialized at creation time and is immutable. - isTTY bool - - // seekable is false if the host fd points to a file representing a stream, - // e.g. a socket or a pipe. Such files are not seekable and can return - // EWOULDBLOCK for I/O operations. - // - // This field is initialized at creation time and is immutable. - seekable bool - - // wouldBlock is true if the host FD would return EWOULDBLOCK for - // operations that would block. - // - // This field is initialized at creation time and is immutable. - wouldBlock bool - - // Event queue for blocking operations. - queue waiter.Queue - - // canMap specifies whether we allow the file to be memory mapped. - // - // This field is initialized at creation time and is immutable. - canMap bool - - // mapsMu protects mappings. - mapsMu sync.Mutex `state:"nosave"` - - // If canMap is true, mappings tracks mappings of hostFD into - // memmap.MappingSpaces. - mappings memmap.MappingSet - - // pf implements platform.File for mappings of hostFD. - pf inodePlatformFile -} - // CheckPermissions implements kernfs.Inode.CheckPermissions. func (i *inode) CheckPermissions(ctx context.Context, creds *auth.Credentials, ats vfs.AccessTypes) error { var s syscall.Stat_t @@ -422,14 +430,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre oldpgend, _ := usermem.PageRoundUp(oldSize) newpgend, _ := usermem.PageRoundUp(s.Size) if oldpgend != newpgend { - i.mapsMu.Lock() - i.mappings.Invalidate(memmap.MappableRange{newpgend, oldpgend}, memmap.InvalidateOpts{ - // Compare Linux's mm/truncate.c:truncate_setsize() => - // truncate_pagecache() => - // mm/memory.c:unmap_mapping_range(evencows=1). - InvalidatePrivate: true, - }) - i.mapsMu.Unlock() + i.CachedMappable.InvalidateRange(memmap.MappableRange{newpgend, oldpgend}) } } } @@ -448,7 +449,7 @@ func (i *inode) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Cre // DecRef implements kernfs.Inode.DecRef. func (i *inode) DecRef(ctx context.Context) { i.inodeRefs.DecRef(func() { - if i.wouldBlock { + if i.mayBlock { fdnotifier.RemoveFD(int32(i.hostFD)) } if err := unix.Close(i.hostFD); err != nil { @@ -567,6 +568,13 @@ func (f *fileDescription) Allocate(ctx context.Context, mode, offset, length uin // PRead implements vfs.FileDescriptionImpl.PRead. func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { return 0, syserror.ESPIPE @@ -577,19 +585,31 @@ func (f *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, off // Read implements vfs.FileDescriptionImpl.Read. func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Check that flags are supported. + // + // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. + if opts.Flags&^linux.RWF_HIPRI != 0 { + return 0, syserror.EOPNOTSUPP + } + i := f.inode if !i.seekable { + bufN, err := i.readFromBuf(ctx, &dst) + if err != nil { + return bufN, err + } n, err := readFromHostFD(ctx, i.hostFD, dst, -1, opts.Flags) + total := bufN + n if isBlockError(err) { // If we got any data at all, return it as a "completed" partial read // rather than retrying until complete. - if n != 0 { + if total != 0 { err = nil } else { err = syserror.ErrWouldBlock } } - return n, err + return total, err } f.offsetMu.Lock() @@ -599,13 +619,26 @@ func (f *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts return n, err } -func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { - // Check that flags are supported. - // - // TODO(gvisor.dev/issue/2601): Support select preadv2 flags. - if flags&^linux.RWF_HIPRI != 0 { - return 0, syserror.EOPNOTSUPP +func (i *inode) readFromBuf(ctx context.Context, dst *usermem.IOSequence) (int64, error) { + if atomic.LoadUint32(&i.haveBuf) == 0 { + return 0, nil + } + i.bufMu.Lock() + defer i.bufMu.Unlock() + if len(i.buf) == 0 { + return 0, nil } + n, err := dst.CopyOut(ctx, i.buf) + *dst = dst.DropFirst(n) + i.buf = i.buf[n:] + if len(i.buf) == 0 { + atomic.StoreUint32(&i.haveBuf, 0) + i.buf = nil + } + return int64(n), err +} + +func readFromHostFD(ctx context.Context, hostFD int, dst usermem.IOSequence, offset int64, flags uint32) (int64, error) { reader := hostfd.GetReadWriterAt(int32(hostFD), offset, flags) n, err := dst.CopyOutFrom(ctx, reader) hostfd.PutReadWriterAt(reader) @@ -735,31 +768,37 @@ func (f *fileDescription) Seek(_ context.Context, offset int64, whence int32) (i } // Sync implements vfs.FileDescriptionImpl.Sync. -func (f *fileDescription) Sync(context.Context) error { +func (f *fileDescription) Sync(ctx context.Context) error { // TODO(gvisor.dev/issue/1897): Currently, we always sync everything. return unix.Fsync(f.inode.hostFD) } // ConfigureMMap implements vfs.FileDescriptionImpl.ConfigureMMap. func (f *fileDescription) ConfigureMMap(_ context.Context, opts *memmap.MMapOpts) error { - if !f.inode.canMap { + // NOTE(b/38213152): Technically, some obscure char devices can be memory + // mapped, but we only allow regular files. + if f.inode.ftype != syscall.S_IFREG { return syserror.ENODEV } i := f.inode - i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) + i.CachedMappable.InitFileMapperOnce() return vfs.GenericConfigureMMap(&f.vfsfd, i, opts) } // EventRegister implements waiter.Waitable.EventRegister. func (f *fileDescription) EventRegister(e *waiter.Entry, mask waiter.EventMask) { f.inode.queue.EventRegister(e, mask) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // EventUnregister implements waiter.Waitable.EventUnregister. func (f *fileDescription) EventUnregister(e *waiter.Entry) { f.inode.queue.EventUnregister(e) - fdnotifier.UpdateFD(int32(f.inode.hostFD)) + if f.inode.mayBlock { + fdnotifier.UpdateFD(int32(f.inode.hostFD)) + } } // Readiness uses the poll() syscall to check the status of the underlying FD. diff --git a/pkg/sentry/fsimpl/host/save_restore.go b/pkg/sentry/fsimpl/host/save_restore.go new file mode 100644 index 000000000..8800652a9 --- /dev/null +++ b/pkg/sentry/fsimpl/host/save_restore.go @@ -0,0 +1,70 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package host + +import ( + "fmt" + "io" + "sync/atomic" + "syscall" + + "gvisor.dev/gvisor/pkg/fdnotifier" + "gvisor.dev/gvisor/pkg/safemem" + "gvisor.dev/gvisor/pkg/sentry/hostfd" + "gvisor.dev/gvisor/pkg/usermem" +) + +// beforeSave is invoked by stateify. +func (i *inode) beforeSave() { + if !i.savable { + panic("host.inode is not savable") + } + if i.ftype == syscall.S_IFIFO { + // If this pipe FD is readable, drain it so that bytes in the pipe can + // be read after restore. (This is a legacy VFS1 feature.) We don't + // know if the pipe FD is readable, so just try reading and tolerate + // EBADF from the read. + i.bufMu.Lock() + defer i.bufMu.Unlock() + var buf [usermem.PageSize]byte + for { + n, err := hostfd.Preadv2(int32(i.hostFD), safemem.BlockSeqOf(safemem.BlockFromSafeSlice(buf[:])), -1 /* offset */, 0 /* flags */) + if n != 0 { + i.buf = append(i.buf, buf[:n]...) + } + if err != nil { + if err == io.EOF || err == syscall.EAGAIN || err == syscall.EBADF { + break + } + panic(fmt.Errorf("host.inode.beforeSave: buffering from pipe failed: %v", err)) + } + } + if len(i.buf) != 0 { + atomic.StoreUint32(&i.haveBuf, 1) + } + } +} + +// afterLoad is invoked by stateify. +func (i *inode) afterLoad() { + if i.mayBlock { + if err := syscall.SetNonblock(i.hostFD, true); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: failed to set host FD %d non-blocking: %v", i.hostFD, err)) + } + if err := fdnotifier.AddFD(int32(i.hostFD), &i.queue); err != nil { + panic(fmt.Sprintf("host.inode.afterLoad: fdnotifier.AddFD(%d) failed: %v", i.hostFD, err)) + } + } +} diff --git a/pkg/sentry/fsimpl/host/util.go b/pkg/sentry/fsimpl/host/util.go index 412bdb2eb..b2f43a119 100644 --- a/pkg/sentry/fsimpl/host/util.go +++ b/pkg/sentry/fsimpl/host/util.go @@ -43,12 +43,6 @@ func timespecToStatxTimestamp(ts unix.Timespec) linux.StatxTimestamp { return linux.StatxTimestamp{Sec: int64(ts.Sec), Nsec: uint32(ts.Nsec)} } -// wouldBlock returns true for file types that can return EWOULDBLOCK -// for blocking operations, e.g. pipes, character devices, and sockets. -func wouldBlock(fileType uint32) bool { - return fileType == syscall.S_IFIFO || fileType == syscall.S_IFCHR || fileType == syscall.S_IFSOCK -} - // isBlockError checks if an error is EAGAIN or EWOULDBLOCK. // If so, they can be transformed into syserror.ErrWouldBlock. func isBlockError(err error) bool { diff --git a/pkg/sentry/fsimpl/kernfs/BUILD b/pkg/sentry/fsimpl/kernfs/BUILD index 858cc24ce..6dbc7e34d 100644 --- a/pkg/sentry/fsimpl/kernfs/BUILD +++ b/pkg/sentry/fsimpl/kernfs/BUILD @@ -4,6 +4,18 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") licenses(["notice"]) go_template_instance( + name = "dentry_list", + out = "dentry_list.go", + package = "kernfs", + prefix = "dentry", + template = "//pkg/ilist:generic_list", + types = { + "Element": "*Dentry", + "Linker": "*Dentry", + }, +) + +go_template_instance( name = "fstree", out = "fstree.go", package = "kernfs", @@ -27,22 +39,11 @@ go_template_instance( ) go_template_instance( - name = "dentry_refs", - out = "dentry_refs.go", - package = "kernfs", - prefix = "Dentry", - template = "//pkg/refs_vfs2:refs_template", - types = { - "T": "Dentry", - }, -) - -go_template_instance( name = "static_directory_refs", out = "static_directory_refs.go", package = "kernfs", prefix = "StaticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "StaticDirectory", }, @@ -53,7 +54,7 @@ go_template_instance( out = "dir_refs.go", package = "kernfs_test", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -64,7 +65,7 @@ go_template_instance( out = "readonly_dir_refs.go", package = "kernfs_test", prefix = "readonlyDir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "readonlyDir", }, @@ -75,7 +76,7 @@ go_template_instance( out = "synthetic_directory_refs.go", package = "kernfs", prefix = "syntheticDirectory", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "syntheticDirectory", }, @@ -84,13 +85,15 @@ go_template_instance( go_library( name = "kernfs", srcs = [ - "dentry_refs.go", + "dentry_list.go", "dynamic_bytes_file.go", "fd_impl_util.go", "filesystem.go", "fstree.go", "inode_impl_util.go", "kernfs.go", + "mmap_util.go", + "save_restore.go", "slot_list.go", "static_directory_refs.go", "symlink.go", @@ -104,8 +107,12 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", + "//pkg/safemem", + "//pkg/sentry/fs/fsutil", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", + "//pkg/sentry/kernel/time", "//pkg/sentry/memmap", "//pkg/sentry/socket/unix/transport", "//pkg/sentry/vfs", @@ -129,6 +136,7 @@ go_test( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/contexttest", "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go index b929118b1..485504995 100644 --- a/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go +++ b/pkg/sentry/fsimpl/kernfs/dynamic_bytes_file.go @@ -47,11 +47,11 @@ type DynamicBytesFile struct { var _ Inode = (*DynamicBytesFile)(nil) // Init initializes a dynamic bytes file. -func (f *DynamicBytesFile) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { +func (f *DynamicBytesFile) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - f.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) f.data = data } diff --git a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go index abf1905d6..f8dae22f8 100644 --- a/pkg/sentry/fsimpl/kernfs/fd_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/fd_impl_util.go @@ -145,8 +145,12 @@ func (fd *GenericDirectoryFD) filesystem() *vfs.Filesystem { return fd.vfsfd.VirtualDentry().Mount().Filesystem() } +func (fd *GenericDirectoryFD) dentry() *Dentry { + return fd.vfsfd.Dentry().Impl().(*Dentry) +} + func (fd *GenericDirectoryFD) inode() Inode { - return fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode + return fd.dentry().inode } // IterDirents implements vfs.FileDescriptionImpl.IterDirents. IterDirents holds @@ -176,8 +180,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent // Handle "..". if fd.off == 1 { - vfsd := fd.vfsfd.VirtualDentry().Dentry() - parentInode := genericParentOrSelf(vfsd.Impl().(*Dentry)).inode + parentInode := genericParentOrSelf(fd.dentry()).inode stat, err := parentInode.Stat(ctx, fd.filesystem(), opts) if err != nil { return err @@ -219,7 +222,7 @@ func (fd *GenericDirectoryFD) IterDirents(ctx context.Context, cb vfs.IterDirent var err error relOffset := fd.off - int64(len(fd.children.set)) - 2 - fd.off, err = fd.inode().IterDirents(ctx, cb, fd.off, relOffset) + fd.off, err = fd.inode().IterDirents(ctx, fd.vfsfd.Mount(), cb, fd.off, relOffset) return err } @@ -265,8 +268,7 @@ func (fd *GenericDirectoryFD) Stat(ctx context.Context, opts vfs.StatOptions) (l // SetStat implements vfs.FileDescriptionImpl.SetStat. func (fd *GenericDirectoryFD) SetStat(ctx context.Context, opts vfs.SetStatOptions) error { creds := auth.CredentialsFromContext(ctx) - inode := fd.vfsfd.VirtualDentry().Dentry().Impl().(*Dentry).inode - return inode.SetStat(ctx, fd.filesystem(), creds, opts) + return fd.inode().SetStat(ctx, fd.filesystem(), creds, opts) } // Allocate implements vfs.FileDescriptionImpl.Allocate. diff --git a/pkg/sentry/fsimpl/kernfs/filesystem.go b/pkg/sentry/fsimpl/kernfs/filesystem.go index 6426a55f6..e77523f22 100644 --- a/pkg/sentry/fsimpl/kernfs/filesystem.go +++ b/pkg/sentry/fsimpl/kernfs/filesystem.go @@ -207,24 +207,23 @@ func (fs *Filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.Resolving // Preconditions: // * Filesystem.mu must be locked for at least reading. // * isDir(parentInode) == true. -func checkCreateLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *Dentry) (string, error) { - if err := parent.inode.CheckPermissions(ctx, rp.Credentials(), vfs.MayWrite|vfs.MayExec); err != nil { - return "", err +func checkCreateLocked(ctx context.Context, creds *auth.Credentials, name string, parent *Dentry) error { + if err := parent.inode.CheckPermissions(ctx, creds, vfs.MayWrite|vfs.MayExec); err != nil { + return err } - pc := rp.Component() - if pc == "." || pc == ".." { - return "", syserror.EEXIST + if name == "." || name == ".." { + return syserror.EEXIST } - if len(pc) > linux.NAME_MAX { - return "", syserror.ENAMETOOLONG + if len(name) > linux.NAME_MAX { + return syserror.ENAMETOOLONG } - if _, ok := parent.children[pc]; ok { - return "", syserror.EEXIST + if _, ok := parent.children[name]; ok { + return syserror.EEXIST } if parent.VFSDentry().IsDead() { - return "", syserror.ENOENT + return syserror.ENOENT } - return pc, nil + return nil } // checkDeleteLocked checks that the file represented by vfsd may be deleted. @@ -245,7 +244,41 @@ func checkDeleteLocked(ctx context.Context, rp *vfs.ResolvingPath, d *Dentry) er } // Release implements vfs.FilesystemImpl.Release. -func (fs *Filesystem) Release(context.Context) { +func (fs *Filesystem) Release(ctx context.Context) { + root := fs.root + if root == nil { + return + } + fs.mu.Lock() + root.releaseKeptDentriesLocked(ctx) + for fs.cachedDentriesLen != 0 { + fs.evictCachedDentryLocked(ctx) + } + fs.mu.Unlock() + // Drop ref acquired in Dentry.InitRoot(). + root.DecRef(ctx) +} + +// releaseKeptDentriesLocked recursively drops all dentry references created by +// Lookup when Dentry.inode.Keep() is true. +// +// Precondition: Filesystem.mu is held. +func (d *Dentry) releaseKeptDentriesLocked(ctx context.Context) { + if d.inode.Keep() && d != d.fs.root { + d.decRefLocked(ctx) + } + + if d.isDir() { + var children []*Dentry + d.dirMu.Lock() + for _, child := range d.children { + children = append(children, child) + } + d.dirMu.Unlock() + for _, child := range children { + child.releaseKeptDentriesLocked(ctx) + } + } } // Sync implements vfs.FilesystemImpl.Sync. @@ -318,10 +351,13 @@ func (fs *Filesystem) LinkAt(ctx context.Context, rp *vfs.ResolvingPath, vd vfs. parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if rp.Mount() != vd.Mount() { return syserror.EXDEV } @@ -360,8 +396,8 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } if err := rp.Mount().CheckBeginWrite(); err != nil { @@ -373,7 +409,7 @@ func (fs *Filesystem) MkdirAt(ctx context.Context, rp *vfs.ResolvingPath, opts v if !opts.ForSyntheticMountpoint || err == syserror.EEXIST { return err } - childI = newSyntheticDirectory(rp.Credentials(), opts.Mode) + childI = newSyntheticDirectory(ctx, rp.Credentials(), opts.Mode) } var child Dentry child.Init(fs, childI) @@ -396,10 +432,13 @@ func (fs *Filesystem) MknodAt(ctx context.Context, rp *vfs.ResolvingPath, opts v parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } @@ -517,9 +556,6 @@ afterTrailingSymlink: } var child Dentry child.Init(fs, childI) - // FIXME(gvisor.dev/issue/1193): Race between checking existence with - // fs.stepExistingLocked and parent.insertChild. If possible, we should hold - // dirMu from one to the other. parent.insertChild(pc, &child) // Open may block so we need to unlock fs.mu. IncRef child to prevent // its destruction while fs.mu is unlocked. @@ -626,8 +662,8 @@ func (fs *Filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // Can we create the dst dentry? var dst *Dentry - pc, err := checkCreateLocked(ctx, rp, dstDir) - switch err { + pc := rp.Component() + switch err := checkCreateLocked(ctx, rp.Credentials(), pc, dstDir); err { case nil: // Ok, continue with rename as replacement. case syserror.EEXIST: @@ -791,10 +827,13 @@ func (fs *Filesystem) SymlinkAt(ctx context.Context, rp *vfs.ResolvingPath, targ parent.dirMu.Lock() defer parent.dirMu.Unlock() - pc, err := checkCreateLocked(ctx, rp, parent) - if err != nil { + pc := rp.Component() + if err := checkCreateLocked(ctx, rp.Credentials(), pc, parent); err != nil { return err } + if rp.MustBeDir() { + return syserror.ENOENT + } if err := rp.Mount().CheckBeginWrite(); err != nil { return err } diff --git a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go index 122b10591..d83c17f83 100644 --- a/pkg/sentry/fsimpl/kernfs/inode_impl_util.go +++ b/pkg/sentry/fsimpl/kernfs/inode_impl_util.go @@ -21,9 +21,11 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" + ktime "gvisor.dev/gvisor/pkg/sentry/kernel/time" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" ) // InodeNoopRefCount partially implements the Inode interface, specifically the @@ -143,7 +145,7 @@ func (InodeNotDirectory) Lookup(ctx context.Context, name string) (Inode, error) } // IterDirents implements Inode.IterDirents. -func (InodeNotDirectory) IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (InodeNotDirectory) IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { panic("IterDirents called on non-directory inode") } @@ -172,17 +174,23 @@ func (InodeNotSymlink) Getlink(context.Context, *vfs.Mount) (vfs.VirtualDentry, // // +stateify savable type InodeAttrs struct { - devMajor uint32 - devMinor uint32 - ino uint64 - mode uint32 - uid uint32 - gid uint32 - nlink uint32 + devMajor uint32 + devMinor uint32 + ino uint64 + mode uint32 + uid uint32 + gid uint32 + nlink uint32 + blockSize uint32 + + // Timestamps, all nsecs from the Unix epoch. + atime int64 + mtime int64 + ctime int64 } // Init initializes this InodeAttrs. -func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { +func (a *InodeAttrs) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, mode linux.FileMode) { if mode.FileType() == 0 { panic(fmt.Sprintf("No file type specified in 'mode' for InodeAttrs.Init(): mode=0%o", mode)) } @@ -198,6 +206,11 @@ func (a *InodeAttrs) Init(creds *auth.Credentials, devMajor, devMinor uint32, in atomic.StoreUint32(&a.uid, uint32(creds.EffectiveKUID)) atomic.StoreUint32(&a.gid, uint32(creds.EffectiveKGID)) atomic.StoreUint32(&a.nlink, nlink) + atomic.StoreUint32(&a.blockSize, usermem.PageSize) + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.atime, now) + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) } // DevMajor returns the device major number. @@ -220,12 +233,33 @@ func (a *InodeAttrs) Mode() linux.FileMode { return linux.FileMode(atomic.LoadUint32(&a.mode)) } +// TouchAtime updates a.atime to the current time. +func (a *InodeAttrs) TouchAtime(ctx context.Context, mnt *vfs.Mount) { + if mnt.Flags.NoATime || mnt.ReadOnly() { + return + } + if err := mnt.CheckBeginWrite(); err != nil { + return + } + atomic.StoreInt64(&a.atime, ktime.NowFromContext(ctx).Nanoseconds()) + mnt.EndWrite() +} + +// TouchCMtime updates a.{c/m}time to the current time. The caller should +// synchronize calls to this so that ctime and mtime are updated to the same +// value. +func (a *InodeAttrs) TouchCMtime(ctx context.Context) { + now := ktime.NowFromContext(ctx).Nanoseconds() + atomic.StoreInt64(&a.mtime, now) + atomic.StoreInt64(&a.ctime, now) +} + // Stat partially implements Inode.Stat. Note that this function doesn't provide // all the stat fields, and the embedder should consider extending the result // with filesystem-specific fields. func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (linux.Statx, error) { var stat linux.Statx - stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK + stat.Mask = linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO | linux.STATX_NLINK | linux.STATX_ATIME | linux.STATX_MTIME | linux.STATX_CTIME stat.DevMajor = a.devMajor stat.DevMinor = a.devMinor stat.Ino = atomic.LoadUint64(&a.ino) @@ -233,21 +267,15 @@ func (a *InodeAttrs) Stat(context.Context, *vfs.Filesystem, vfs.StatOptions) (li stat.UID = atomic.LoadUint32(&a.uid) stat.GID = atomic.LoadUint32(&a.gid) stat.Nlink = atomic.LoadUint32(&a.nlink) - - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - + stat.Blksize = atomic.LoadUint32(&a.blockSize) + stat.Atime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.atime)) + stat.Mtime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.mtime)) + stat.Ctime = linux.NsecToStatxTimestamp(atomic.LoadInt64(&a.ctime)) return stat, nil } // SetStat implements Inode.SetStat. func (a *InodeAttrs) SetStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { - return a.SetInodeStat(ctx, fs, creds, opts) -} - -// SetInodeStat sets the corresponding attributes from opts to InodeAttrs. -// This function can be used by other kernfs-based filesystem implementation to -// sets the unexported attributes into InodeAttrs. -func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds *auth.Credentials, opts vfs.SetStatOptions) error { if opts.Stat.Mask == 0 { return nil } @@ -256,9 +284,7 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds // inode numbers are immutable after node creation. Setting the size is often // allowed by kernfs files but does not do anything. If some other behavior is // needed, the embedder should consider extending SetStat. - // - // TODO(gvisor.dev/issue/1193): Implement other stat fields like timestamps. - if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_SIZE) != 0 { + if opts.Stat.Mask&^(linux.STATX_MODE|linux.STATX_UID|linux.STATX_GID|linux.STATX_ATIME|linux.STATX_MTIME|linux.STATX_SIZE) != 0 { return syserror.EPERM } if opts.Stat.Mask&linux.STATX_SIZE != 0 && a.Mode().IsDir() { @@ -286,6 +312,20 @@ func (a *InodeAttrs) SetInodeStat(ctx context.Context, fs *vfs.Filesystem, creds atomic.StoreUint32(&a.gid, stat.GID) } + now := ktime.NowFromContext(ctx).Nanoseconds() + if stat.Mask&linux.STATX_ATIME != 0 { + if stat.Atime.Nsec == linux.UTIME_NOW { + stat.Atime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.atime, stat.Atime.ToNsec()) + } + if stat.Mask&linux.STATX_MTIME != 0 { + if stat.Mtime.Nsec == linux.UTIME_NOW { + stat.Mtime = linux.NsecToStatxTimestamp(now) + } + atomic.StoreInt64(&a.mtime, stat.Mtime.ToNsec()) + } + return nil } @@ -421,7 +461,7 @@ func (o *OrderedChildren) Lookup(ctx context.Context, name string) (Inode, error } // IterDirents implements Inode.IterDirents. -func (o *OrderedChildren) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { +func (o *OrderedChildren) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { // All entries from OrderedChildren have already been handled in // GenericDirectoryFD.IterDirents. return offset, nil @@ -528,13 +568,6 @@ func (o *OrderedChildren) RmDir(ctx context.Context, name string, child Inode) e return o.Unlink(ctx, name, child) } -// +stateify savable -type renameAcrossDifferentImplementationsError struct{} - -func (renameAcrossDifferentImplementationsError) Error() string { - return "rename across inodes with different implementations" -} - // Rename implements Inode.Rename. // // Precondition: Rename may only be called across two directory inodes with @@ -545,13 +578,18 @@ func (renameAcrossDifferentImplementationsError) Error() string { // // Postcondition: reference on any replaced dentry transferred to caller. func (o *OrderedChildren) Rename(ctx context.Context, oldname, newname string, child, dstDir Inode) error { + if !o.writable { + return syserror.EPERM + } + dst, ok := dstDir.(interface{}).(*OrderedChildren) if !ok { - return renameAcrossDifferentImplementationsError{} + return syserror.EXDEV } - if !o.writable || !dst.writable { + if !dst.writable { return syserror.EPERM } + // Note: There's a potential deadlock below if concurrent calls to Rename // refer to the same src and dst directories in reverse. We avoid any // ordering issues because the caller is required to serialize concurrent @@ -619,9 +657,9 @@ type StaticDirectory struct { var _ Inode = (*StaticDirectory)(nil) // NewStaticDir creates a new static directory and returns its dentry. -func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { +func NewStaticDir(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, children map[string]Inode, fdOpts GenericDirectoryFDOptions) Inode { inode := &StaticDirectory{} - inode.Init(creds, devMajor, devMinor, ino, perm, fdOpts) + inode.Init(ctx, creds, devMajor, devMinor, ino, perm, fdOpts) inode.EnableLeakCheck() inode.OrderedChildren.Init(OrderedChildrenOptions{}) @@ -632,12 +670,12 @@ func NewStaticDir(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64 } // Init initializes StaticDirectory. -func (s *StaticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { +func (s *StaticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode, fdOpts GenericDirectoryFDOptions) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } s.fdOpts = fdOpts - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeDirectory|perm) } // Open implements Inode.Open. diff --git a/pkg/sentry/fsimpl/kernfs/kernfs.go b/pkg/sentry/fsimpl/kernfs/kernfs.go index 606081e68..abb477c7d 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs.go @@ -61,6 +61,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sync" @@ -107,6 +108,23 @@ type Filesystem struct { // nextInoMinusOne is used to to allocate inode numbers on this // filesystem. Must be accessed by atomic operations. nextInoMinusOne uint64 + + // cachedDentries contains all dentries with 0 references. (Due to race + // conditions, it may also contain dentries with non-zero references.) + // cachedDentriesLen is the number of dentries in cachedDentries. These + // fields are protected by mu. + cachedDentries dentryList + cachedDentriesLen uint64 + + // MaxCachedDentries is the maximum size of cachedDentries. If not set, + // defaults to 0 and kernfs does not cache any dentries. This is immutable. + MaxCachedDentries uint64 + + // root is the root dentry of this filesystem. Note that root may be nil for + // filesystems on a disconnected mount without a root (e.g. pipefs, sockfs, + // hostfs). Filesystem holds an extra reference on root to prevent it from + // being destroyed prematurely. This is immutable. + root *Dentry } // deferDecRef defers dropping a dentry ref until the next call to @@ -165,7 +183,12 @@ const ( // +stateify savable type Dentry struct { vfsd vfs.Dentry - DentryRefs + + // refs is the reference count. When refs reaches 0, the dentry may be + // added to the cache or destroyed. If refs == -1, the dentry has already + // been destroyed. refs are allowed to go to 0 and increase again. refs is + // accessed using atomic memory operations. + refs int64 // fs is the owning filesystem. fs is immutable. fs *Filesystem @@ -177,6 +200,12 @@ type Dentry struct { parent *Dentry name string + // If cached is true, dentryEntry links dentry into + // Filesystem.cachedDentries. cached and dentryEntry are protected by + // Filesystem.mu. + cached bool + dentryEntry + // dirMu protects children and the names of child Dentries. // // Note that holding fs.mu for writing is not sufficient; @@ -188,6 +217,201 @@ type Dentry struct { inode Inode } +// IncRef implements vfs.DentryImpl.IncRef. +func (d *Dentry) IncRef() { + // d.refs may be 0 if d.fs.mu is locked, which serializes against + // d.cacheLocked(). + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) +} + +// TryIncRef implements vfs.DentryImpl.TryIncRef. +func (d *Dentry) TryIncRef() bool { + for { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { + return false + } + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) + return true + } + } +} + +// DecRef implements vfs.DentryImpl.DecRef. +func (d *Dentry) DecRef(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.fs.mu.Lock() + d.cacheLocked(ctx) + d.fs.mu.Unlock() + } else if r < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +func (d *Dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.cacheLocked(ctx) + } else if r < 0 { + panic("kernfs.Dentry.DecRef() called without holding a reference") + } +} + +// cacheLocked should be called after d's reference count becomes 0. The ref +// count check may happen before acquiring d.fs.mu so there might be a race +// condition where the ref count is increased again by the time the caller +// acquires d.fs.mu. This race is handled. +// Only reachable dentries are added to the cache. However, a dentry might +// become unreachable *while* it is in the cache due to invalidation. +// +// Preconditions: d.fs.mu must be locked for writing. +func (d *Dentry) cacheLocked(ctx context.Context) { + // Dentries with a non-zero reference count must be retained. (The only way + // to obtain a reference on a dentry with zero references is via path + // resolution, which requires d.fs.mu, so if d.refs is zero then it will + // remain zero while we hold d.fs.mu for writing.) + refs := atomic.LoadInt64(&d.refs) + if refs == -1 { + // Dentry has already been destroyed. + panic(fmt.Sprintf("cacheLocked called on a dentry which has already been destroyed: %v", d)) + } + if refs > 0 { + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + return + } + // If the dentry is deleted and invalidated or has no parent, then it is no + // longer reachable by path resolution and should be dropped immediately + // because it has zero references. + // Note that a dentry may not always have a parent; for example magic links + // as described in Inode.Getlink. + if isDead := d.VFSDentry().IsDead(); isDead || d.parent == nil { + if !isDead { + d.fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) + } + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentriesLen-- + d.cached = false + } + d.destroyLocked(ctx) + return + } + // If d is already cached, just move it to the front of the LRU. + if d.cached { + d.fs.cachedDentries.Remove(d) + d.fs.cachedDentries.PushFront(d) + return + } + // Cache the dentry, then evict the least recently used cached dentry if + // the cache becomes over-full. + d.fs.cachedDentries.PushFront(d) + d.fs.cachedDentriesLen++ + d.cached = true + if d.fs.cachedDentriesLen <= d.fs.MaxCachedDentries { + return + } + d.fs.evictCachedDentryLocked(ctx) + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.opts.maxCachedDentries, so we don't loop. +} + +// Preconditions: +// * fs.mu must be locked for writing. +// * fs.cachedDentriesLen != 0. +func (fs *Filesystem) evictCachedDentryLocked(ctx context.Context) { + // Evict the least recently used dentry because cache size is greater than + // max cache size (configured on mount). + victim := fs.cachedDentries.Back() + fs.cachedDentries.Remove(victim) + fs.cachedDentriesLen-- + victim.cached = false + // victim.refs may have become non-zero from an earlier path resolution + // after it was inserted into fs.cachedDentries. + if atomic.LoadInt64(&victim.refs) == 0 { + if !victim.vfsd.IsDead() { + victim.parent.dirMu.Lock() + // Note that victim can't be a mount point (in any mount + // namespace), since VFS holds references on mount points. + fs.vfsfs.VirtualFilesystem().InvalidateDentry(ctx, victim.VFSDentry()) + delete(victim.parent.children, victim.name) + victim.parent.dirMu.Unlock() + } + victim.destroyLocked(ctx) + } + // Whether or not victim was destroyed, we brought fs.cachedDentriesLen + // back down to fs.MaxCachedDentries, so we don't loop. +} + +// destroyLocked destroys the dentry. +// +// Preconditions: +// * d.fs.mu must be locked for writing. +// * d.refs == 0. +// * d should have been removed from d.parent.children, i.e. d is not reachable +// by path traversal. +// * d.vfsd.IsDead() is true. +func (d *Dentry) destroyLocked(ctx context.Context) { + refs := atomic.LoadInt64(&d.refs) + switch refs { + case 0: + // Mark the dentry destroyed. + atomic.StoreInt64(&d.refs, -1) + case -1: + panic("dentry.destroyLocked() called on already destroyed dentry") + default: + panic("dentry.destroyLocked() called with references on the dentry") + } + + d.inode.DecRef(ctx) // IncRef from Init. + d.inode = nil + + if d.parent != nil { + d.parent.decRefLocked(ctx) + } + + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *Dentry) RefType() string { + return "kernfs.Dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *Dentry) LeakMessage() string { + return fmt.Sprintf("[kernfs.Dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *Dentry) LogRefs() bool { + return false +} + +// InitRoot initializes this dentry as the root of the filesystem. +// +// Precondition: Caller must hold a reference on inode. +// +// Postcondition: Caller's reference on inode is transferred to the dentry. +func (d *Dentry) InitRoot(fs *Filesystem, inode Inode) { + d.Init(fs, inode) + fs.root = d + // Hold an extra reference on the root dentry. It is held by fs to prevent the + // root from being "cached" and subsequently evicted. + d.IncRef() +} + // Init initializes this dentry. // // Precondition: Caller must hold a reference on inode. @@ -197,6 +421,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { d.vfsd.Init(d) d.fs = fs d.inode = inode + atomic.StoreInt64(&d.refs, 1) ftype := inode.Mode().FileType() if ftype == linux.ModeDirectory { d.flags |= dflagsIsDir @@ -204,7 +429,7 @@ func (d *Dentry) Init(fs *Filesystem, inode Inode) { if ftype == linux.ModeSymlink { d.flags |= dflagsIsSymlink } - d.EnableLeakCheck() + refsvfs2.Register(d) } // VFSDentry returns the generic vfs dentry for this kernfs dentry. @@ -222,32 +447,6 @@ func (d *Dentry) isSymlink() bool { return atomic.LoadUint32(&d.flags)&dflagsIsSymlink != 0 } -// DecRef implements vfs.DentryImpl.DecRef. -func (d *Dentry) DecRef(ctx context.Context) { - decRefParent := false - d.fs.mu.Lock() - d.DentryRefs.DecRef(func() { - d.inode.DecRef(ctx) // IncRef from Init. - d.inode = nil - if d.parent != nil { - // We will DecRef d.parent once all locks are dropped. - decRefParent = true - d.parent.dirMu.Lock() - // Remove d from parent.children. It might already have been - // removed due to invalidation. - if _, ok := d.parent.children[d.name]; ok { - delete(d.parent.children, d.name) - d.fs.VFSFilesystem().VirtualFilesystem().InvalidateDentry(ctx, d.VFSDentry()) - } - d.parent.dirMu.Unlock() - } - }) - d.fs.mu.Unlock() - if decRefParent { - d.parent.DecRef(ctx) // IncRef from Dentry.insertChild. - } -} - // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. // // Although Linux technically supports inotify on pseudo filesystems (inotify @@ -267,7 +466,9 @@ func (d *Dentry) OnZeroWatches(context.Context) {} // this dentry. This does not update the directory inode, so calling this on its // own isn't sufficient to insert a child into a directory. // -// Precondition: d must represent a directory inode. +// Preconditions: +// * d must represent a directory inode. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChild(name string, child *Dentry) { d.dirMu.Lock() d.insertChildLocked(name, child) @@ -280,6 +481,7 @@ func (d *Dentry) insertChild(name string, child *Dentry) { // Preconditions: // * d must represent a directory inode. // * d.dirMu must be locked. +// * d.fs.mu must be locked for at least reading. func (d *Dentry) insertChildLocked(name string, child *Dentry) { if !d.isDir() { panic(fmt.Sprintf("insertChildLocked called on non-directory Dentry: %+v.", d)) @@ -436,7 +638,7 @@ type inodeDirectory interface { // the inode is a directory. // // The child returned by Lookup will be hashed into the VFS dentry tree, - // atleast for the duration of the current FS operation. + // at least for the duration of the current FS operation. // // Lookup must return the child with an extra reference whose ownership is // transferred to the dentry that is created to point to that inode. If @@ -454,7 +656,7 @@ type inodeDirectory interface { // inside the entries returned by this IterDirents invocation. In other words, // 'offset' should be used to calculate each vfs.Dirent.NextOff as well as // the return value, while 'relOffset' is the place to start iteration. - IterDirents(ctx context.Context, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) + IterDirents(ctx context.Context, mnt *vfs.Mount, callback vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) } type inodeSymlink interface { diff --git a/pkg/sentry/fsimpl/kernfs/kernfs_test.go b/pkg/sentry/fsimpl/kernfs/kernfs_test.go index 82fa19c03..2418eec44 100644 --- a/pkg/sentry/fsimpl/kernfs/kernfs_test.go +++ b/pkg/sentry/fsimpl/kernfs/kernfs_test.go @@ -36,7 +36,7 @@ const staticFileContent = "This is sample content for a static test file." // RootDentryFn is a generator function for creating the root dentry of a test // filesystem. See newTestSystem. -type RootDentryFn func(*auth.Credentials, *filesystem) kernfs.Inode +type RootDentryFn func(context.Context, *auth.Credentials, *filesystem) kernfs.Inode // newTestSystem sets up a minimal environment for running a test, including an // instance of a test filesystem. Tests can control the contents of the @@ -72,10 +72,10 @@ type file struct { content string } -func (fs *filesystem) newFile(creds *auth.Credentials, content string) kernfs.Inode { +func (fs *filesystem) newFile(ctx context.Context, creds *auth.Credentials, content string) kernfs.Inode { f := &file{} f.content = content - f.DynamicBytesFile.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) + f.DynamicBytesFile.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), f, 0777) return f } @@ -105,9 +105,9 @@ type readonlyDir struct { locks vfs.FileLocks } -func (fs *filesystem) newReadonlyDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newReadonlyDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &readonlyDir{} - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) dir.EnableLeakCheck() dir.IncLinks(dir.OrderedChildren.Populate(contents)) @@ -142,10 +142,10 @@ type dir struct { fs *filesystem } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { dir := &dir{} dir.fs = fs - dir.attrs.Init(creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) + dir.attrs.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, fs.NextIno(), linux.ModeDirectory|mode) dir.OrderedChildren.Init(kernfs.OrderedChildrenOptions{Writable: true}) dir.EnableLeakCheck() @@ -169,22 +169,24 @@ func (d *dir) DecRef(ctx context.Context) { func (d *dir) NewDir(ctx context.Context, name string, opts vfs.MkdirOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - dir := d.fs.newDir(creds, opts.Mode, nil) + dir := d.fs.newDir(ctx, creds, opts.Mode, nil) if err := d.OrderedChildren.Insert(name, dir); err != nil { dir.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) d.IncLinks(1) return dir, nil } func (d *dir) NewFile(ctx context.Context, name string, opts vfs.OpenOptions) (kernfs.Inode, error) { creds := auth.CredentialsFromContext(ctx) - f := d.fs.newFile(creds, "") + f := d.fs.newFile(ctx, creds, "") if err := d.OrderedChildren.Insert(name, f); err != nil { f.DecRef(ctx) return nil, err } + d.TouchCMtime(ctx) return f, nil } @@ -209,7 +211,7 @@ func (fsType) Release(ctx context.Context) {} func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opt vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { fs := &filesystem{} fs.VFSFilesystem().Init(vfsObj, &fst, fs) - root := fst.rootFn(creds, fs) + root := fst.rootFn(ctx, creds, fs) var d kernfs.Dentry d.Init(&fs.Filesystem, root) return fs.VFSFilesystem(), d.VFSDentry(), nil @@ -218,9 +220,9 @@ func (fst fsType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesyst // -------------------- Remainder of the file are test cases -------------------- func TestBasic(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -228,9 +230,9 @@ func TestBasic(t *testing.T) { } func TestMkdirGetDentry(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -243,9 +245,9 @@ func TestMkdirGetDentry(t *testing.T) { } func TestReadStaticFile(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "file1": fs.newFile(creds, staticFileContent), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() @@ -269,9 +271,9 @@ func TestReadStaticFile(t *testing.T) { } func TestCreateNewFileInStaticDir(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ - "dir1": fs.newDir(creds, 0755, nil), + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir1": fs.newDir(ctx, creds, 0755, nil), }) }) defer sys.Destroy() @@ -296,8 +298,8 @@ func TestCreateNewFileInStaticDir(t *testing.T) { } func TestDirFDReadWrite(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, nil) + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, nil) }) defer sys.Destroy() @@ -320,14 +322,14 @@ func TestDirFDReadWrite(t *testing.T) { } func TestDirFDIterDirents(t *testing.T) { - sys := newTestSystem(t, func(creds *auth.Credentials, fs *filesystem) kernfs.Inode { - return fs.newReadonlyDir(creds, 0755, map[string]kernfs.Inode{ + sys := newTestSystem(t, func(ctx context.Context, creds *auth.Credentials, fs *filesystem) kernfs.Inode { + return fs.newReadonlyDir(ctx, creds, 0755, map[string]kernfs.Inode{ // Fill root with nodes backed by various inode implementations. - "dir1": fs.newReadonlyDir(creds, 0755, nil), - "dir2": fs.newDir(creds, 0755, map[string]kernfs.Inode{ - "dir3": fs.newDir(creds, 0755, nil), + "dir1": fs.newReadonlyDir(ctx, creds, 0755, nil), + "dir2": fs.newDir(ctx, creds, 0755, map[string]kernfs.Inode{ + "dir3": fs.newDir(ctx, creds, 0755, nil), }), - "file1": fs.newFile(creds, staticFileContent), + "file1": fs.newFile(ctx, creds, staticFileContent), }) }) defer sys.Destroy() diff --git a/pkg/sentry/fsimpl/host/mmap.go b/pkg/sentry/fsimpl/kernfs/mmap_util.go index b51a17bed..bd6a134b4 100644 --- a/pkg/sentry/fsimpl/host/mmap.go +++ b/pkg/sentry/fsimpl/kernfs/mmap_util.go @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -package host +package kernfs import ( "gvisor.dev/gvisor/pkg/context" @@ -26,11 +26,14 @@ import ( // inodePlatformFile implements memmap.File. It exists solely because inode // cannot implement both kernfs.Inode.IncRef and memmap.File.IncRef. // -// inodePlatformFile should only be used if inode.canMap is true. -// // +stateify savable type inodePlatformFile struct { - *inode + // hostFD contains the host fd that this file was originally created from, + // which must be available at time of restore. + // + // This field is initialized at creation time and is immutable. + // inodePlatformFile does not own hostFD and hence should not close it. + hostFD int // fdRefsMu protects fdRefs. fdRefsMu sync.Mutex `state:"nosave"` @@ -43,12 +46,12 @@ type inodePlatformFile struct { fileMapper fsutil.HostFileMapper // fileMapperInitOnce is used to lazily initialize fileMapper. - fileMapperInitOnce sync.Once `state:"nosave"` // FIXME(gvisor.dev/issue/1663): not yet supported. + fileMapperInitOnce sync.Once `state:"nosave"` } +var _ memmap.File = (*inodePlatformFile)(nil) + // IncRef implements memmap.File.IncRef. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.IncRefAndAccount(fr) @@ -56,8 +59,6 @@ func (i *inodePlatformFile) IncRef(fr memmap.FileRange) { } // DecRef implements memmap.File.DecRef. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { i.fdRefsMu.Lock() i.fdRefs.DecRefAndAccount(fr) @@ -65,8 +66,6 @@ func (i *inodePlatformFile) DecRef(fr memmap.FileRange) { } // MapInternal implements memmap.File.MapInternal. -// -// Precondition: i.inode.canMap must be true. func (i *inodePlatformFile) MapInternal(fr memmap.FileRange, at usermem.AccessType) (safemem.BlockSeq, error) { return i.fileMapper.MapInternal(fr, i.hostFD, at.Write) } @@ -76,10 +75,32 @@ func (i *inodePlatformFile) FD() int { return i.hostFD } -// AddMapping implements memmap.Mappable.AddMapping. +// CachedMappable implements memmap.Mappable. This utility can be embedded in a +// kernfs.Inode that represents a host file to make the inode mappable. +// CachedMappable caches the mappings of the host file. CachedMappable must be +// initialized (via Init) with a hostFD before use. // -// Precondition: i.inode.canMap must be true. -func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { +// +stateify savable +type CachedMappable struct { + // mapsMu protects mappings. + mapsMu sync.Mutex `state:"nosave"` + + // mappings tracks mappings of hostFD into memmap.MappingSpaces. + mappings memmap.MappingSet + + // pf implements memmap.File for mappings backed by a host fd. + pf inodePlatformFile +} + +var _ memmap.Mappable = (*CachedMappable)(nil) + +// Init initializes i.pf. This must be called before using CachedMappable. +func (i *CachedMappable) Init(hostFD int) { + i.pf.hostFD = hostFD +} + +// AddMapping implements memmap.Mappable.AddMapping. +func (i *CachedMappable) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) error { i.mapsMu.Lock() mapped := i.mappings.AddMapping(ms, ar, offset, writable) for _, r := range mapped { @@ -90,9 +111,7 @@ func (i *inode) AddMapping(ctx context.Context, ms memmap.MappingSpace, ar userm } // RemoveMapping implements memmap.Mappable.RemoveMapping. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { +func (i *CachedMappable) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar usermem.AddrRange, offset uint64, writable bool) { i.mapsMu.Lock() unmapped := i.mappings.RemoveMapping(ms, ar, offset, writable) for _, r := range unmapped { @@ -102,16 +121,12 @@ func (i *inode) RemoveMapping(ctx context.Context, ms memmap.MappingSpace, ar us } // CopyMapping implements memmap.Mappable.CopyMapping. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { +func (i *CachedMappable) CopyMapping(ctx context.Context, ms memmap.MappingSpace, srcAR, dstAR usermem.AddrRange, offset uint64, writable bool) error { return i.AddMapping(ctx, ms, dstAR, offset, writable) } // Translate implements memmap.Mappable.Translate. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { +func (i *CachedMappable) Translate(ctx context.Context, required, optional memmap.MappableRange, at usermem.AccessType) ([]memmap.Translation, error) { mr := optional return []memmap.Translation{ { @@ -124,10 +139,26 @@ func (i *inode) Translate(ctx context.Context, required, optional memmap.Mappabl } // InvalidateUnsavable implements memmap.Mappable.InvalidateUnsavable. -// -// Precondition: i.inode.canMap must be true. -func (i *inode) InvalidateUnsavable(ctx context.Context) error { +func (i *CachedMappable) InvalidateUnsavable(ctx context.Context) error { // We expect the same host fd across save/restore, so all translations // should be valid. return nil } + +// InvalidateRange invalidates the passed range on i.mappings. +func (i *CachedMappable) InvalidateRange(r memmap.MappableRange) { + i.mapsMu.Lock() + i.mappings.Invalidate(r, memmap.InvalidateOpts{ + // Compare Linux's mm/truncate.c:truncate_setsize() => + // truncate_pagecache() => + // mm/memory.c:unmap_mapping_range(evencows=1). + InvalidatePrivate: true, + }) + i.mapsMu.Unlock() +} + +// InitFileMapperOnce initializes the host file mapper. It ensures that the +// file mapper is initialized just once. +func (i *CachedMappable) InitFileMapperOnce() { + i.pf.fileMapperInitOnce.Do(i.pf.fileMapper.Init) +} diff --git a/pkg/sentry/fsimpl/kernfs/save_restore.go b/pkg/sentry/fsimpl/kernfs/save_restore.go new file mode 100644 index 000000000..f78509eb7 --- /dev/null +++ b/pkg/sentry/fsimpl/kernfs/save_restore.go @@ -0,0 +1,36 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package kernfs + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// afterLoad is invoked by stateify. +func (d *Dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) >= 0 { + refsvfs2.Register(d) + } +} + +// afterLoad is invoked by stateify. +func (i *inodePlatformFile) afterLoad() { + if i.fileMapper.IsInited() { + // Ensure that we don't call i.fileMapper.Init() again. + i.fileMapperInitOnce.Do(func() {}) + } +} diff --git a/pkg/sentry/fsimpl/kernfs/symlink.go b/pkg/sentry/fsimpl/kernfs/symlink.go index 934cc6c9e..a0736c0d6 100644 --- a/pkg/sentry/fsimpl/kernfs/symlink.go +++ b/pkg/sentry/fsimpl/kernfs/symlink.go @@ -38,16 +38,16 @@ type StaticSymlink struct { var _ Inode = (*StaticSymlink)(nil) // NewStaticSymlink creates a new symlink file pointing to 'target'. -func NewStaticSymlink(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { +func NewStaticSymlink(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, target string) Inode { inode := &StaticSymlink{} - inode.Init(creds, devMajor, devMinor, ino, target) + inode.Init(ctx, creds, devMajor, devMinor, ino, target) return inode } // Init initializes the instance. -func (s *StaticSymlink) Init(creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { +func (s *StaticSymlink) Init(ctx context.Context, creds *auth.Credentials, devMajor uint32, devMinor uint32, ino uint64, target string) { s.target = target - s.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) + s.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeSymlink|0777) } // Readlink implements Inode.Readlink. diff --git a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go index d0ed17b18..463d77d79 100644 --- a/pkg/sentry/fsimpl/kernfs/synthetic_directory.go +++ b/pkg/sentry/fsimpl/kernfs/synthetic_directory.go @@ -41,17 +41,17 @@ type syntheticDirectory struct { var _ Inode = (*syntheticDirectory)(nil) -func newSyntheticDirectory(creds *auth.Credentials, perm linux.FileMode) Inode { +func newSyntheticDirectory(ctx context.Context, creds *auth.Credentials, perm linux.FileMode) Inode { inode := &syntheticDirectory{} - inode.Init(creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) + inode.Init(ctx, creds, 0 /* devMajor */, 0 /* devMinor */, 0 /* ino */, perm) return inode } -func (dir *syntheticDirectory) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (dir *syntheticDirectory) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("perm contains non-permission bits: %#o", perm)) } - dir.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) + dir.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.S_IFDIR|perm) dir.OrderedChildren.Init(OrderedChildrenOptions{ Writable: true, }) @@ -76,11 +76,12 @@ func (dir *syntheticDirectory) NewDir(ctx context.Context, name string, opts vfs if !opts.ForSyntheticMountpoint { return nil, syserror.EPERM } - subdirI := newSyntheticDirectory(auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) + subdirI := newSyntheticDirectory(ctx, auth.CredentialsFromContext(ctx), opts.Mode&linux.PermissionsMask) if err := dir.OrderedChildren.Insert(name, subdirI); err != nil { subdirI.DecRef(ctx) return nil, err } + dir.TouchCMtime(ctx) return subdirI, nil } diff --git a/pkg/sentry/fsimpl/overlay/BUILD b/pkg/sentry/fsimpl/overlay/BUILD index 1e11b0428..bf13bbbf4 100644 --- a/pkg/sentry/fsimpl/overlay/BUILD +++ b/pkg/sentry/fsimpl/overlay/BUILD @@ -23,6 +23,7 @@ go_library( "fstree.go", "overlay.go", "regular_file.go", + "save_restore.go", ], visibility = ["//pkg/sentry:internal"], deps = [ @@ -30,6 +31,8 @@ go_library( "//pkg/context", "//pkg/fspath", "//pkg/log", + "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/fsimpl/overlay/copy_up.go b/pkg/sentry/fsimpl/overlay/copy_up.go index 4506642ca..469f3a33d 100644 --- a/pkg/sentry/fsimpl/overlay/copy_up.go +++ b/pkg/sentry/fsimpl/overlay/copy_up.go @@ -409,7 +409,7 @@ func (d *dentry) copyUpDescendantsLocked(ctx context.Context, ds **[]*dentry) er if dirent.Name == "." || dirent.Name == ".." { continue } - child, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) + child, _, err := d.fs.getChildLocked(ctx, d, dirent.Name, ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/filesystem.go b/pkg/sentry/fsimpl/overlay/filesystem.go index 78a01bbb7..bc07d72c0 100644 --- a/pkg/sentry/fsimpl/overlay/filesystem.go +++ b/pkg/sentry/fsimpl/overlay/filesystem.go @@ -22,6 +22,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sentry/vfs" @@ -121,63 +122,63 @@ func (fs *filesystem) renameMuUnlockAndCheckDrop(ctx context.Context, ds **[]*de // * fs.renameMu must be locked. // * d.dirMu must be locked. // * !rp.Done(). -func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) stepLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, mayFollowSymlinks bool, ds **[]*dentry) (*dentry, lookupLayer, error) { if !d.isDir() { - return nil, syserror.ENOTDIR + return nil, lookupLayerNone, syserror.ENOTDIR } if err := d.checkPermissions(rp.Credentials(), vfs.MayExec); err != nil { - return nil, err + return nil, lookupLayerNone, err } afterSymlink: name := rp.Component() if name == "." { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if name == ".." { if isRoot, err := rp.CheckRoot(ctx, &d.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } else if isRoot || d.parent == nil { rp.Advance() - return d, nil + return d, d.topLookupLayer(), nil } if err := rp.CheckMount(ctx, &d.parent.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } rp.Advance() - return d.parent, nil + return d.parent, d.parent.topLookupLayer(), nil } - child, err := fs.getChildLocked(ctx, d, name, ds) + child, topLookupLayer, err := fs.getChildLocked(ctx, d, name, ds) if err != nil { - return nil, err + return nil, topLookupLayer, err } if err := rp.CheckMount(ctx, &child.vfsd); err != nil { - return nil, err + return nil, lookupLayerNone, err } if child.isSymlink() && mayFollowSymlinks && rp.ShouldFollowSymlink() { target, err := child.readlink(ctx) if err != nil { - return nil, err + return nil, lookupLayerNone, err } if err := rp.HandleSymlink(target); err != nil { - return nil, err + return nil, topLookupLayer, err } goto afterSymlink // don't check the current directory again } rp.Advance() - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * d.dirMu must be locked. -func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { +func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, lookupLayer, error) { if child, ok := parent.children[name]; ok { - return child, nil + return child, child.topLookupLayer(), nil } - child, err := fs.lookupLocked(ctx, parent, name) + child, topLookupLayer, err := fs.lookupLocked(ctx, parent, name) if err != nil { - return nil, err + return nil, topLookupLayer, err } if parent.children == nil { parent.children = make(map[string]*dentry) @@ -185,16 +186,16 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s parent.children[name] = child // child's refcount is initially 0, so it may be dropped after traversal. *ds = appendDentry(*ds, child) - return child, nil + return child, topLookupLayer, nil } // Preconditions: // * fs.renameMu must be locked. // * parent.dirMu must be locked. -func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { +func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name string) (*dentry, lookupLayer, error) { childPath := fspath.Parse(name) child := fs.newDentry() - existsOnAnyLayer := false + topLookupLayer := lookupLayerNone var lookupErr error vfsObj := fs.vfsfs.VirtualFilesystem() @@ -215,7 +216,7 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str defer childVD.DecRef(ctx) mask := uint32(linux.STATX_TYPE) - if !existsOnAnyLayer { + if topLookupLayer == lookupLayerNone { // Mode, UID, GID, and (for non-directories) inode number come from // the topmost layer on which the file exists. mask |= linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID | linux.STATX_INO @@ -238,10 +239,13 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if isWhiteout(&stat) { // This is a whiteout, so it "doesn't exist" on this layer, and // layers below this one are ignored. + if isUpper { + topLookupLayer = lookupLayerUpperWhiteout + } return false } isDir := stat.Mode&linux.S_IFMT == linux.S_IFDIR - if existsOnAnyLayer && !isDir { + if topLookupLayer != lookupLayerNone && !isDir { // Directories are not merged with non-directory files from lower // layers; instead, layers including and below the first // non-directory file are ignored. (This file must be a directory @@ -258,8 +262,12 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str } else { child.lowerVDs = append(child.lowerVDs, childVD) } - if !existsOnAnyLayer { - existsOnAnyLayer = true + if topLookupLayer == lookupLayerNone { + if isUpper { + topLookupLayer = lookupLayerUpper + } else { + topLookupLayer = lookupLayerLower + } child.mode = uint32(stat.Mode) child.uid = stat.UID child.gid = stat.GID @@ -288,11 +296,11 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str if lookupErr != nil { child.destroyLocked(ctx) - return nil, lookupErr + return nil, topLookupLayer, lookupErr } - if !existsOnAnyLayer { + if !topLookupLayer.existsInOverlay() { child.destroyLocked(ctx) - return nil, syserror.ENOENT + return nil, topLookupLayer, syserror.ENOENT } // Device and inode numbers were copied from the topmost layer above; @@ -302,14 +310,20 @@ func (fs *filesystem) lookupLocked(ctx context.Context, parent *dentry, name str child.devMinor = fs.dirDevMinor child.ino = fs.newDirIno() } else if !child.upperVD.Ok() { + childDevMinor, err := fs.getLowerDevMinor(child.devMajor, child.devMinor) + if err != nil { + ctx.Infof("overlay.filesystem.lookupLocked: failed to map lower layer device number (%d, %d) to an overlay-specific device number: %v", child.devMajor, child.devMinor, err) + child.destroyLocked(ctx) + return nil, topLookupLayer, err + } child.devMajor = linux.UNNAMED_MAJOR - child.devMinor = fs.lowerDevMinors[child.lowerVDs[0].Mount().Filesystem()] + child.devMinor = childDevMinor } parent.IncRef() child.parent = parent child.name = name - return child, nil + return child, topLookupLayer, nil } // lookupLayerLocked is similar to lookupLocked, but only returns information @@ -408,7 +422,7 @@ func (ll lookupLayer) existsInOverlay() bool { func (fs *filesystem) walkParentDirLocked(ctx context.Context, rp *vfs.ResolvingPath, d *dentry, ds **[]*dentry) (*dentry, error) { for !rp.Final() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -428,7 +442,7 @@ func (fs *filesystem) resolveLocked(ctx context.Context, rp *vfs.ResolvingPath, d := rp.Start().Impl().(*dentry) for !rp.Done() { d.dirMu.Lock() - next, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) + next, _, err := fs.stepLocked(ctx, rp, d, true /* mayFollowSymlinks */, ds) d.dirMu.Unlock() if err != nil { return nil, err @@ -463,9 +477,6 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir if name == "." || name == ".." { return syserror.EEXIST } - if !dir && rp.MustBeDir() { - return syserror.ENOENT - } if parent.vfsd.IsDead() { return syserror.ENOENT } @@ -489,6 +500,10 @@ func (fs *filesystem) doCreateAt(ctx context.Context, rp *vfs.ResolvingPath, dir return syserror.EEXIST } + if !dir && rp.MustBeDir() { + return syserror.ENOENT + } + // Ensure that the parent directory is copied-up so that we can create the // new file in the upper layer. if err := parent.copyUpLocked(ctx); err != nil { @@ -791,9 +806,9 @@ afterTrailingSymlink: } // Determine whether or not we need to create a file. parent.dirMu.Lock() - child, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) + child, topLookupLayer, err := fs.stepLocked(ctx, rp, parent, false /* mayFollowSymlinks */, &ds) if err == syserror.ENOENT && mayCreate { - fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds) + fd, err := fs.createAndOpenLocked(ctx, rp, parent, &opts, &ds, topLookupLayer == lookupLayerUpperWhiteout) parent.dirMu.Unlock() return fd, err } @@ -893,7 +908,7 @@ func (d *dentry) openCopiedUp(ctx context.Context, rp *vfs.ResolvingPath, opts * // Preconditions: // * parent.dirMu must be locked. // * parent does not already contain a child named rp.Component(). -func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry) (*vfs.FileDescription, error) { +func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.ResolvingPath, parent *dentry, opts *vfs.OpenOptions, ds **[]*dentry, haveUpperWhiteout bool) (*vfs.FileDescription, error) { creds := rp.Credentials() if err := parent.checkPermissions(creds, vfs.MayWrite); err != nil { return nil, err @@ -918,19 +933,12 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving Start: parent.upperVD, Path: fspath.Parse(childName), } - // We don't know if a whiteout exists on the upper layer; speculatively - // unlink it. - // - // TODO(gvisor.dev/issue/1199): Modify OpenAt => stepLocked so that we do - // know whether a whiteout exists. - var haveUpperWhiteout bool - switch err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err { - case nil: - haveUpperWhiteout = true - case syserror.ENOENT: - haveUpperWhiteout = false - default: - return nil, err + // Unlink the whiteout if it exists. + if haveUpperWhiteout { + if err := vfsObj.UnlinkAt(ctx, fs.creds, &pop); err != nil { + log.Warningf("overlay.filesystem.createAndOpenLocked: failed to unlink whiteout: %v", err) + return nil, err + } } // Create the file on the upper layer, and get an FD representing it. upperFD, err := vfsObj.OpenAt(ctx, fs.creds, &pop, &vfs.OpenOptions{ @@ -961,7 +969,7 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } // Re-lookup to get a dentry representing the new file, which is needed for // the returned FD. - child, err := fs.getChildLocked(ctx, parent, childName, ds) + child, _, err := fs.getChildLocked(ctx, parent, childName, ds) if err != nil { if cleanupErr := vfsObj.UnlinkAt(ctx, fs.creds, &pop); cleanupErr != nil { panic(fmt.Sprintf("unrecoverable overlayfs inconsistency: failed to delete upper layer file after OpenAt(O_CREAT) dentry lookup failure: %v", cleanupErr)) @@ -970,7 +978,10 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving } return nil, err } - // Finally construct the overlay FD. + // Finally construct the overlay FD. Below this point, we don't perform + // cleanup (the file was created successfully even if we can no longer open + // it for some reason). + parent.dirents = nil upperFlags := upperFD.StatusFlags() fd := ®ularFileFD{ copiedUp: true, @@ -981,8 +992,6 @@ func (fs *filesystem) createAndOpenLocked(ctx context.Context, rp *vfs.Resolving upperFDOpts := upperFD.Options() if err := fd.vfsfd.Init(fd, upperFlags, mnt, &child.vfsd, &upperFDOpts); err != nil { upperFD.DecRef(ctx) - // Don't bother with cleanup; the file was created successfully, we - // just can't open it anymore for some reason. return nil, err } parent.watches.Notify(ctx, childName, linux.IN_CREATE, 0 /* cookie */, vfs.PathEvent, false /* unlinked */) @@ -1040,7 +1049,7 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa // directory, we need to check for write permission on it. oldParent.dirMu.Lock() defer oldParent.dirMu.Unlock() - renamed, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) + renamed, _, err := fs.getChildLocked(ctx, oldParent, oldName, &ds) if err != nil { return err } @@ -1072,20 +1081,17 @@ func (fs *filesystem) RenameAt(ctx context.Context, rp *vfs.ResolvingPath, oldPa if newParent.vfsd.IsDead() { return syserror.ENOENT } - replacedLayer, err := fs.lookupLayerLocked(ctx, newParent, newName) - if err != nil { - return err - } var ( - replaced *dentry - replacedVFSD *vfs.Dentry - whiteouts map[string]bool + replaced *dentry + replacedVFSD *vfs.Dentry + replacedLayer lookupLayer + whiteouts map[string]bool ) - if replacedLayer.existsInOverlay() { - replaced, err = fs.getChildLocked(ctx, newParent, newName, &ds) - if err != nil { - return err - } + replaced, replacedLayer, err = fs.getChildLocked(ctx, newParent, newName, &ds) + if err != nil && err != syserror.ENOENT { + return err + } + if replaced != nil { replacedVFSD = &replaced.vfsd if replaced.isDir() { if !renamed.isDir() { @@ -1289,7 +1295,7 @@ func (fs *filesystem) RmdirAt(ctx context.Context, rp *vfs.ResolvingPath) error // Unlike UnlinkAt, we need a dentry representing the child directory being // removed in order to verify that it's empty. - child, err := fs.getChildLocked(ctx, parent, name, &ds) + child, _, err := fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } @@ -1541,7 +1547,7 @@ func (fs *filesystem) UnlinkAt(ctx context.Context, rp *vfs.ResolvingPath) error if parentMode&linux.S_ISVTX != 0 { // If the parent's sticky bit is set, we need a child dentry to get // its owner. - child, err = fs.getChildLocked(ctx, parent, name, &ds) + child, _, err = fs.getChildLocked(ctx, parent, name, &ds) if err != nil { return err } diff --git a/pkg/sentry/fsimpl/overlay/overlay.go b/pkg/sentry/fsimpl/overlay/overlay.go index 4c5de8d32..73130bc8d 100644 --- a/pkg/sentry/fsimpl/overlay/overlay.go +++ b/pkg/sentry/fsimpl/overlay/overlay.go @@ -22,6 +22,7 @@ // filesystem.renameMu // dentry.dirMu // dentry.copyMu +// filesystem.devMu // *** "memmap.Mappable locks" below this point // dentry.mapsMu // *** "memmap.Mappable locks taken by Translate" below this point @@ -33,12 +34,14 @@ package overlay import ( + "fmt" "strings" "sync/atomic" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" + "gvisor.dev/gvisor/pkg/refsvfs2" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/memmap" @@ -99,10 +102,15 @@ type filesystem struct { // is immutable. dirDevMinor uint32 - // lowerDevMinors maps lower layer filesystems to device minor numbers - // assigned to non-directory files originating from that filesystem. - // lowerDevMinors is immutable. - lowerDevMinors map[*vfs.Filesystem]uint32 + // lowerDevMinors maps device numbers from lower layer filesystems to + // device minor numbers assigned to non-directory files originating from + // that filesystem. (This remapping is necessary for lower layers because a + // file on a lower layer, and that same file on an overlay, are + // distinguishable because they will diverge after copy-up; this isn't true + // for non-directory files already on the upper layer.) lowerDevMinors is + // protected by devMu. + devMu sync.Mutex `state:"nosave"` + lowerDevMinors map[layerDevNumber]uint32 // renameMu synchronizes renaming with non-renaming operations in order to // ensure consistent lock ordering between dentry.dirMu in different @@ -114,78 +122,69 @@ type filesystem struct { lastDirIno uint64 } +// +stateify savable +type layerDevNumber struct { + major uint32 + minor uint32 +} + // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, source string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { mopts := vfs.GenericParseMountOptions(opts.Data) fsoptsRaw := opts.InternalData - fsopts, haveFSOpts := fsoptsRaw.(FilesystemOptions) - if fsoptsRaw != nil && !haveFSOpts { + fsopts, ok := fsoptsRaw.(FilesystemOptions) + if fsoptsRaw != nil && !ok { ctx.Infof("overlay.FilesystemType.GetFilesystem: GetFilesystemOptions.InternalData has type %T, wanted overlay.FilesystemOptions or nil", fsoptsRaw) return nil, nil, syserror.EINVAL } - if haveFSOpts { - if len(fsopts.LowerRoots) == 0 { - ctx.Infof("overlay.FilesystemType.GetFilesystem: LowerRoots must be non-empty") + vfsroot := vfs.RootFromContext(ctx) + if vfsroot.Ok() { + defer vfsroot.DecRef(ctx) + } + + if upperPathname, ok := mopts["upperdir"]; ok { + if fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both upperdir and FilesystemOptions.UpperRoot are specified") return nil, nil, syserror.EINVAL } - if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two LowerRoots are required when UpperRoot is unspecified") + delete(mopts, "upperdir") + // Linux overlayfs also requires a workdir when upperdir is + // specified; we don't, so silently ignore this option. + delete(mopts, "workdir") + upperPath := fspath.Parse(upperPathname) + if !upperPath.Absolute { + ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) return nil, nil, syserror.EINVAL } - // We don't enforce a maximum number of lower layers when not - // configured by applications; the sandbox owner can have an overlay - // filesystem with any number of lower layers. - } else { - vfsroot := vfs.RootFromContext(ctx) - defer vfsroot.DecRef(ctx) - upperPathname, ok := mopts["upperdir"] - if ok { - delete(mopts, "upperdir") - // Linux overlayfs also requires a workdir when upperdir is - // specified; we don't, so silently ignore this option. - delete(mopts, "workdir") - upperPath := fspath.Parse(upperPathname) - if !upperPath.Absolute { - ctx.Infof("overlay.FilesystemType.GetFilesystem: upperdir %q must be absolute", upperPathname) - return nil, nil, syserror.EINVAL - } - upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ - Root: vfsroot, - Start: vfsroot, - Path: upperPath, - FollowFinalSymlink: true, - }, &vfs.GetDentryOptions{ - CheckSearchable: true, - }) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer upperRoot.DecRef(ctx) - privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) - if err != nil { - ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) - return nil, nil, err - } - defer privateUpperRoot.DecRef(ctx) - fsopts.UpperRoot = privateUpperRoot + upperRoot, err := vfsObj.GetDentryAt(ctx, creds, &vfs.PathOperation{ + Root: vfsroot, + Start: vfsroot, + Path: upperPath, + FollowFinalSymlink: true, + }, &vfs.GetDentryOptions{ + CheckSearchable: true, + }) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve upperdir %q: %v", upperPathname, err) + return nil, nil, err + } + privateUpperRoot, err := clonePrivateMount(vfsObj, upperRoot, false /* forceReadOnly */) + upperRoot.DecRef(ctx) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of upperdir %q: %v", upperPathname, err) + return nil, nil, err } - lowerPathnamesStr, ok := mopts["lowerdir"] - if !ok { - ctx.Infof("overlay.FilesystemType.GetFilesystem: missing required option lowerdir") + defer privateUpperRoot.DecRef(ctx) + fsopts.UpperRoot = privateUpperRoot + } + + if lowerPathnamesStr, ok := mopts["lowerdir"]; ok { + if len(fsopts.LowerRoots) != 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: both lowerdir and FilesystemOptions.LowerRoots are specified") return nil, nil, syserror.EINVAL } delete(mopts, "lowerdir") lowerPathnames := strings.Split(lowerPathnamesStr, ":") - const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK - if len(lowerPathnames) < 2 && !fsopts.UpperRoot.Ok() { - ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lowerdirs are required when upperdir is unspecified") - return nil, nil, syserror.EINVAL - } - if len(lowerPathnames) > maxLowerLayers { - ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lowerdirs specified, maximum %d", len(lowerPathnames), maxLowerLayers) - return nil, nil, syserror.EINVAL - } for _, lowerPathname := range lowerPathnames { lowerPath := fspath.Parse(lowerPathname) if !lowerPath.Absolute { @@ -204,8 +203,8 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to resolve lowerdir %q: %v", lowerPathname, err) return nil, nil, err } - defer lowerRoot.DecRef(ctx) privateLowerRoot, err := clonePrivateMount(vfsObj, lowerRoot, true /* forceReadOnly */) + lowerRoot.DecRef(ctx) if err != nil { ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to make private bind mount of lowerdir %q: %v", lowerPathname, err) return nil, nil, err @@ -214,31 +213,31 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fsopts.LowerRoots = append(fsopts.LowerRoots, privateLowerRoot) } } + if len(mopts) != 0 { ctx.Infof("overlay.FilesystemType.GetFilesystem: unused options: %v", mopts) return nil, nil, syserror.EINVAL } - // Allocate device numbers. + if len(fsopts.LowerRoots) == 0 { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least one lower layer is required") + return nil, nil, syserror.EINVAL + } + if len(fsopts.LowerRoots) < 2 && !fsopts.UpperRoot.Ok() { + ctx.Infof("overlay.FilesystemType.GetFilesystem: at least two lower layers are required when no upper layer is present") + return nil, nil, syserror.EINVAL + } + const maxLowerLayers = 500 // Linux: fs/overlay/super.c:OVL_MAX_STACK + if len(fsopts.LowerRoots) > maxLowerLayers { + ctx.Infof("overlay.FilesystemType.GetFilesystem: %d lower layers specified, maximum %d", len(fsopts.LowerRoots), maxLowerLayers) + return nil, nil, syserror.EINVAL + } + + // Allocate dirDevMinor. lowerDevMinors are allocated dynamically. dirDevMinor, err := vfsObj.GetAnonBlockDevMinor() if err != nil { return nil, nil, err } - lowerDevMinors := make(map[*vfs.Filesystem]uint32) - for _, lowerRoot := range fsopts.LowerRoots { - lowerFS := lowerRoot.Mount().Filesystem() - if _, ok := lowerDevMinors[lowerFS]; !ok { - devMinor, err := vfsObj.GetAnonBlockDevMinor() - if err != nil { - vfsObj.PutAnonBlockDevMinor(dirDevMinor) - for _, lowerDevMinor := range lowerDevMinors { - vfsObj.PutAnonBlockDevMinor(lowerDevMinor) - } - return nil, nil, err - } - lowerDevMinors[lowerFS] = devMinor - } - } // Take extra references held by the filesystem. if fsopts.UpperRoot.Ok() { @@ -252,7 +251,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt opts: fsopts, creds: creds.Fork(), dirDevMinor: dirDevMinor, - lowerDevMinors: lowerDevMinors, + lowerDevMinors: make(map[layerDevNumber]uint32), } fs.vfsfs.Init(vfsObj, &fstype, fs) @@ -302,7 +301,14 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt root.ino = fs.newDirIno() } else if !root.upperVD.Ok() { root.devMajor = linux.UNNAMED_MAJOR - root.devMinor = fs.lowerDevMinors[root.lowerVDs[0].Mount().Filesystem()] + rootDevMinor, err := fs.getLowerDevMinor(rootStat.DevMajor, rootStat.DevMinor) + if err != nil { + ctx.Infof("overlay.FilesystemType.GetFilesystem: failed to get device number for root: %v", err) + root.destroyLocked(ctx) + fs.vfsfs.DecRef(ctx) + return nil, nil, err + } + root.devMinor = rootDevMinor root.ino = rootStat.Ino } else { root.devMajor = rootStat.DevMajor @@ -375,6 +381,21 @@ func (fs *filesystem) newDirIno() uint64 { return atomic.AddUint64(&fs.lastDirIno, 1) } +func (fs *filesystem) getLowerDevMinor(layerMajor, layerMinor uint32) (uint32, error) { + fs.devMu.Lock() + defer fs.devMu.Unlock() + orig := layerDevNumber{layerMajor, layerMinor} + if minor, ok := fs.lowerDevMinors[orig]; ok { + return minor, nil + } + minor, err := fs.vfsfs.VirtualFilesystem().GetAnonBlockDevMinor() + if err != nil { + return 0, err + } + fs.lowerDevMinors[orig] = minor + return minor, nil +} + // dentry implements vfs.DentryImpl. // // +stateify savable @@ -458,9 +479,9 @@ type dentry struct { // // - isMappable is non-zero iff wrappedMappable is non-nil. isMappable is // accessed using atomic memory operations. - mapsMu sync.Mutex + mapsMu sync.Mutex `state:"nosave"` lowerMappings memmap.MappingSet - dataMu sync.RWMutex + dataMu sync.RWMutex `state:"nosave"` wrappedMappable memmap.Mappable isMappable uint32 @@ -484,6 +505,7 @@ func (fs *filesystem) newDentry() *dentry { } d.lowerVDs = d.inlineLowerVDs[:0] d.vfsd.Init(d) + refsvfs2.Register(d) return d } @@ -491,17 +513,19 @@ func (fs *filesystem) newDentry() *dentry { func (d *dentry) IncRef() { // d.refs may be 0 if d.fs.renameMu is locked, which serializes against // d.checkDropLocked(). - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) return true } } @@ -509,15 +533,27 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { + } else if r < 0 { panic("overlay.dentry.DecRef() called without holding a reference") } } +func (d *dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.checkDropLocked(ctx) + } else if r < 0 { + panic("overlay.dentry.decRefLocked() called without holding a reference") + } +} + // checkDropLocked should be called after d's reference count becomes 0 or it // becomes deleted. // @@ -577,12 +613,27 @@ func (d *dentry) destroyLocked(ctx context.Context) { d.parent.dirMu.Unlock() // Drop the reference held by d on its parent without recursively // locking d.fs.renameMu. - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked(ctx) - } else if refs < 0 { - panic("overlay.dentry.DecRef() called without holding a reference") - } + d.parent.decRefLocked(ctx) } + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "overlay.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[overlay.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -645,6 +696,13 @@ func (d *dentry) topLayer() vfs.VirtualDentry { return vd } +func (d *dentry) topLookupLayer() lookupLayer { + if d.upperVD.Ok() { + return lookupLayerUpper + } + return lookupLayerLower +} + func (d *dentry) checkPermissions(creds *auth.Credentials, ats vfs.AccessTypes) error { return vfs.GenericCheckPermissions(creds, ats, linux.FileMode(atomic.LoadUint32(&d.mode)), auth.KUID(atomic.LoadUint32(&d.uid)), auth.KGID(atomic.LoadUint32(&d.gid))) } diff --git a/pkg/sentry/fsimpl/overlay/save_restore.go b/pkg/sentry/fsimpl/overlay/save_restore.go new file mode 100644 index 000000000..54809f16c --- /dev/null +++ b/pkg/sentry/fsimpl/overlay/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package overlay + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} diff --git a/pkg/sentry/fsimpl/pipefs/pipefs.go b/pkg/sentry/fsimpl/pipefs/pipefs.go index e44b79b68..0ecb592cf 100644 --- a/pkg/sentry/fsimpl/pipefs/pipefs.go +++ b/pkg/sentry/fsimpl/pipefs/pipefs.go @@ -101,7 +101,7 @@ type inode struct { func newInode(ctx context.Context, fs *filesystem) *inode { creds := auth.CredentialsFromContext(ctx) return &inode{ - pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize), + pipe: pipe.NewVFSPipe(false /* isNamed */, pipe.DefaultPipeSize), ino: fs.Filesystem.NextIno(), uid: creds.EffectiveKUID, gid: creds.EffectiveKGID, diff --git a/pkg/sentry/fsimpl/proc/BUILD b/pkg/sentry/fsimpl/proc/BUILD index 2e086e34c..5196a2a80 100644 --- a/pkg/sentry/fsimpl/proc/BUILD +++ b/pkg/sentry/fsimpl/proc/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "fd_dir_inode_refs.go", package = "proc", prefix = "fdDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdDirInode", }, @@ -19,7 +19,7 @@ go_template_instance( out = "fd_info_dir_inode_refs.go", package = "proc", prefix = "fdInfoDirInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "fdInfoDirInode", }, @@ -30,7 +30,7 @@ go_template_instance( out = "subtasks_inode_refs.go", package = "proc", prefix = "subtasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "subtasksInode", }, @@ -41,7 +41,7 @@ go_template_instance( out = "task_inode_refs.go", package = "proc", prefix = "taskInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "taskInode", }, @@ -52,7 +52,7 @@ go_template_instance( out = "tasks_inode_refs.go", package = "proc", prefix = "tasksInode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tasksInode", }, @@ -82,6 +82,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/fs/lock", "//pkg/sentry/fsbridge", diff --git a/pkg/sentry/fsimpl/proc/filesystem.go b/pkg/sentry/fsimpl/proc/filesystem.go index fd70a07de..8716d0a3c 100644 --- a/pkg/sentry/fsimpl/proc/filesystem.go +++ b/pkg/sentry/fsimpl/proc/filesystem.go @@ -17,6 +17,7 @@ package proc import ( "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -24,10 +25,14 @@ import ( "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "proc" +const ( + // Name is the default filesystem name. + Name = "proc" + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType is the factory class for procfs. // @@ -63,9 +68,22 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF if err != nil { return nil, nil, err } + + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("proc.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + procfs := &filesystem{ devMinor: devMinor, } + procfs.MaxCachedDentries = maxCachedDentries procfs.VFSFilesystem().Init(vfsObj, &ft, procfs) var cgroups map[string]string @@ -74,9 +92,9 @@ func (ft FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualF cgroups = data.Cgroups } - inode := procfs.newTasksInode(k, pidns, cgroups) + inode := procfs.newTasksInode(ctx, k, pidns, cgroups) var dentry kernfs.Dentry - dentry.Init(&procfs.Filesystem, inode) + dentry.InitRoot(&procfs.Filesystem, inode) return procfs.VFSFilesystem(), dentry.VFSDentry(), nil } @@ -94,11 +112,11 @@ type dynamicInode interface { kernfs.Inode vfs.DynamicBytesSource - Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) + Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, data vfs.DynamicBytesSource, perm linux.FileMode) } -func (fs *filesystem) newInode(creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { - inode.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) +func (fs *filesystem) newInode(ctx context.Context, creds *auth.Credentials, perm linux.FileMode, inode dynamicInode) dynamicInode { + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), inode, perm) return inode } @@ -114,8 +132,8 @@ func newStaticFile(data string) *staticFile { return &staticFile{StaticData: vfs.StaticData{Data: data}} } -func (fs *filesystem) newStaticDir(creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { - return kernfs.NewStaticDir(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ +func (fs *filesystem) newStaticDir(ctx context.Context, creds *auth.Credentials, children map[string]kernfs.Inode) kernfs.Inode { + return kernfs.NewStaticDir(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0555, children, kernfs.GenericDirectoryFDOptions{ SeekEnd: kernfs.SeekEndZero, }) } diff --git a/pkg/sentry/fsimpl/proc/subtasks.go b/pkg/sentry/fsimpl/proc/subtasks.go index bad2fab4f..cb3c5e0fd 100644 --- a/pkg/sentry/fsimpl/proc/subtasks.go +++ b/pkg/sentry/fsimpl/proc/subtasks.go @@ -58,7 +58,7 @@ func (fs *filesystem) newSubtasks(task *kernel.Task, pidns *kernel.PIDNamespace, cgroupControllers: cgroupControllers, } // Note: credentials are overridden by taskOwnedInode. - subInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + subInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) subInode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) subInode.EnableLeakCheck() @@ -84,7 +84,7 @@ func (i *subtasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *subtasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *subtasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { tasks := i.task.ThreadGroup().MemberIDs(i.pidns) if len(tasks) == 0 { return offset, syserror.ENOENT diff --git a/pkg/sentry/fsimpl/proc/task.go b/pkg/sentry/fsimpl/proc/task.go index b63a4eca0..19011b010 100644 --- a/pkg/sentry/fsimpl/proc/task.go +++ b/pkg/sentry/fsimpl/proc/task.go @@ -64,6 +64,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace "gid_map": fs.newTaskOwnedInode(task, fs.NextIno(), 0644, &idMapData{task: task, gids: true}), "io": fs.newTaskOwnedInode(task, fs.NextIno(), 0400, newIO(task, isThreadGroup)), "maps": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mapsData{task: task}), + "mem": fs.newMemInode(task, fs.NextIno(), 0400), "mountinfo": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountInfoData{task: task}), "mounts": fs.newTaskOwnedInode(task, fs.NextIno(), 0444, &mountsData{task: task}), "net": fs.newTaskNetDir(task), @@ -89,7 +90,7 @@ func (fs *filesystem) newTaskInode(task *kernel.Task, pidns *kernel.PIDNamespace taskInode := &taskInode{task: task} // Note: credentials are overridden by taskOwnedInode. - taskInode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + taskInode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) taskInode.EnableLeakCheck() inode := &taskOwnedInode{Inode: taskInode, owner: task} @@ -144,7 +145,7 @@ var _ kernfs.Inode = (*taskOwnedInode)(nil) func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linux.FileMode, inode dynamicInode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, inode, perm) return &taskOwnedInode{Inode: inode, owner: task} } @@ -152,7 +153,7 @@ func (fs *filesystem) newTaskOwnedInode(task *kernel.Task, ino uint64, perm linu func (fs *filesystem) newTaskOwnedDir(task *kernel.Task, ino uint64, perm linux.FileMode, children map[string]kernfs.Inode) kernfs.Inode { // Note: credentials are overridden by taskOwnedInode. fdOpts := kernfs.GenericDirectoryFDOptions{SeekEnd: kernfs.SeekEndZero} - dir := kernfs.NewStaticDir(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) + dir := kernfs.NewStaticDir(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm, children, fdOpts) return &taskOwnedInode{Inode: dir, owner: task} } diff --git a/pkg/sentry/fsimpl/proc/task_fds.go b/pkg/sentry/fsimpl/proc/task_fds.go index 2c80ac5c2..d268b44be 100644 --- a/pkg/sentry/fsimpl/proc/task_fds.go +++ b/pkg/sentry/fsimpl/proc/task_fds.go @@ -64,7 +64,7 @@ type fdDir struct { } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDir) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { +func (i *fdDir) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { var fds []int32 i.task.WithMuLocked(func(t *kernel.Task) { if fdTable := t.FDTable(); fdTable != nil { @@ -127,15 +127,15 @@ func (fs *filesystem) newFDDirInode(task *kernel.Task) kernfs.Inode { produceSymlink: true, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *fdDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (int64, error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Lookup implements kernfs.inodeDirectory.Lookup. @@ -209,7 +209,7 @@ func (fs *filesystem) newFDSymlink(task *kernel.Task, fd int32, ino uint64) kern task: task, fd: fd, } - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -264,7 +264,7 @@ func (fs *filesystem) newFDInfoDirInode(task *kernel.Task) kernfs.Inode { task: task, }, } - inode.InodeAttrs.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) return inode @@ -288,8 +288,8 @@ func (i *fdInfoDirInode) Lookup(ctx context.Context, name string) (kernfs.Inode, } // IterDirents implements Inode.IterDirents. -func (i *fdInfoDirInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { - return i.fdDir.IterDirents(ctx, cb, offset, relOffset) +func (i *fdInfoDirInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, relOffset int64) (newOffset int64, err error) { + return i.fdDir.IterDirents(ctx, mnt, cb, offset, relOffset) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_files.go b/pkg/sentry/fsimpl/proc/task_files.go index 79f8b7e9f..ba71d0fde 100644 --- a/pkg/sentry/fsimpl/proc/task_files.go +++ b/pkg/sentry/fsimpl/proc/task_files.go @@ -31,6 +31,7 @@ import ( "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/sentry/usage" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -249,7 +250,7 @@ type commInode struct { func (fs *filesystem) newComm(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { inode := &commInode{task: task} - inode.DynamicBytesFile.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) + inode.DynamicBytesFile.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, &commData{task: task}, perm) return inode } @@ -366,6 +367,162 @@ func (d *idMapData) Write(ctx context.Context, src usermem.IOSequence, offset in return int64(srclen), nil } +var _ kernfs.Inode = (*memInode)(nil) + +// memInode implements kernfs.Inode for /proc/[pid]/mem. +// +// +stateify savable +type memInode struct { + kernfs.InodeAttrs + kernfs.InodeNoStatFS + kernfs.InodeNoopRefCount + kernfs.InodeNotDirectory + kernfs.InodeNotSymlink + + task *kernel.Task + locks vfs.FileLocks +} + +func (fs *filesystem) newMemInode(task *kernel.Task, ino uint64, perm linux.FileMode) kernfs.Inode { + // Note: credentials are overridden by taskOwnedInode. + inode := &memInode{task: task} + inode.init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, perm) + return &taskOwnedInode{Inode: inode, owner: task} +} + +func (f *memInode) init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { + if perm&^linux.PermissionsMask != 0 { + panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) + } + f.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) +} + +// Open implements kernfs.Inode.Open. +func (f *memInode) Open(ctx context.Context, rp *vfs.ResolvingPath, d *kernfs.Dentry, opts vfs.OpenOptions) (*vfs.FileDescription, error) { + // TODO(gvisor.dev/issue/260): Add check for PTRACE_MODE_ATTACH_FSCREDS + // Permission to read this file is governed by PTRACE_MODE_ATTACH_FSCREDS + // Since we dont implement setfsuid/setfsgid we can just use PTRACE_MODE_ATTACH + if !kernel.ContextCanTrace(ctx, f.task, true) { + return nil, syserror.EACCES + } + if err := checkTaskState(f.task); err != nil { + return nil, err + } + fd := &memFD{} + if err := fd.Init(rp.Mount(), d, f, opts.Flags); err != nil { + return nil, err + } + return &fd.vfsfd, nil +} + +// SetStat implements kernfs.Inode.SetStat. +func (*memInode) SetStat(context.Context, *vfs.Filesystem, *auth.Credentials, vfs.SetStatOptions) error { + return syserror.EPERM +} + +var _ vfs.FileDescriptionImpl = (*memFD)(nil) + +// memFD implements vfs.FileDescriptionImpl for /proc/[pid]/mem. +// +// +stateify savable +type memFD struct { + vfsfd vfs.FileDescription + vfs.FileDescriptionDefaultImpl + vfs.LockFD + + inode *memInode + + // mu guards the fields below. + mu sync.Mutex `state:"nosave"` + offset int64 +} + +// Init initializes memFD. +func (fd *memFD) Init(m *vfs.Mount, d *kernfs.Dentry, inode *memInode, flags uint32) error { + fd.LockFD.Init(&inode.locks) + if err := fd.vfsfd.Init(fd, flags, m, d.VFSDentry(), &vfs.FileDescriptionOptions{}); err != nil { + return err + } + fd.inode = inode + return nil +} + +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *memFD) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + switch whence { + case linux.SEEK_SET: + case linux.SEEK_CUR: + offset += fd.offset + default: + return 0, syserror.EINVAL + } + if offset < 0 { + return 0, syserror.EINVAL + } + fd.offset = offset + return offset, nil +} + +// PRead implements vfs.FileDescriptionImpl.PRead. +func (fd *memFD) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { + if dst.NumBytes() == 0 { + return 0, nil + } + m, err := getMMIncRef(fd.inode.task) + if err != nil { + return 0, nil + } + defer m.DecUsers(ctx) + // Buffer the read data because of MM locks + buf := make([]byte, dst.NumBytes()) + n, readErr := m.CopyIn(ctx, usermem.Addr(offset), buf, usermem.IOOpts{IgnorePermissions: true}) + if n > 0 { + if _, err := dst.CopyOut(ctx, buf[:n]); err != nil { + return 0, syserror.EFAULT + } + return int64(n), nil + } + if readErr != nil { + return 0, syserror.EIO + } + return 0, nil +} + +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *memFD) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.offset, opts) + fd.offset += n + fd.mu.Unlock() + return n, err +} + +// Stat implements vfs.FileDescriptionImpl.Stat. +func (fd *memFD) Stat(ctx context.Context, opts vfs.StatOptions) (linux.Statx, error) { + fs := fd.vfsfd.VirtualDentry().Mount().Filesystem() + return fd.inode.Stat(ctx, fs, opts) +} + +// SetStat implements vfs.FileDescriptionImpl.SetStat. +func (fd *memFD) SetStat(context.Context, vfs.SetStatOptions) error { + return syserror.EPERM +} + +// Release implements vfs.FileDescriptionImpl.Release. +func (fd *memFD) Release(context.Context) {} + +// LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. +func (fd *memFD) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { + return fd.Locks().LockPOSIX(ctx, &fd.vfsfd, uid, t, start, length, whence, block) +} + +// UnlockPOSIX implements vfs.FileDescriptionImpl.UnlockPOSIX. +func (fd *memFD) UnlockPOSIX(ctx context.Context, uid fslock.UniqueID, start, length uint64, whence int16) error { + return fd.Locks().UnlockPOSIX(ctx, &fd.vfsfd, uid, start, length, whence) +} + // mapsData implements vfs.DynamicBytesSource for /proc/[pid]/maps. // // +stateify savable @@ -657,7 +814,7 @@ var _ kernfs.Inode = (*exeSymlink)(nil) func (fs *filesystem) newExeSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &exeSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -733,7 +890,7 @@ var _ kernfs.Inode = (*cwdSymlink)(nil) func (fs *filesystem) newCwdSymlink(task *kernel.Task, ino uint64) kernfs.Inode { inode := &cwdSymlink{task: task} - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, linux.ModeSymlink|0777) return inode } @@ -850,7 +1007,7 @@ func (fs *filesystem) newNamespaceSymlink(task *kernel.Task, ino uint64, ns stri inode := &namespaceSymlink{task: task} // Note: credentials are overridden by taskOwnedInode. - inode.Init(task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) + inode.Init(task, task.Credentials(), linux.UNNAMED_MAJOR, fs.devMinor, ino, target) taskInode := &taskOwnedInode{Inode: inode, owner: task} return taskInode @@ -872,8 +1029,10 @@ func (s *namespaceSymlink) Getlink(ctx context.Context, mnt *vfs.Mount) (vfs.Vir // Create a synthetic inode to represent the namespace. fs := mnt.Filesystem().Impl().(*filesystem) + nsInode := &namespaceInode{} + nsInode.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), 0444) dentry := &kernfs.Dentry{} - dentry.Init(&fs.Filesystem, &namespaceInode{}) + dentry.Init(&fs.Filesystem, nsInode) vd := vfs.MakeVirtualDentry(mnt, dentry.VFSDentry()) // Only IncRef vd.Mount() because vd.Dentry() already holds a ref of 1. mnt.IncRef() @@ -897,11 +1056,11 @@ type namespaceInode struct { var _ kernfs.Inode = (*namespaceInode)(nil) // Init initializes a namespace inode. -func (i *namespaceInode) Init(creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { +func (i *namespaceInode) Init(ctx context.Context, creds *auth.Credentials, devMajor, devMinor uint32, ino uint64, perm linux.FileMode) { if perm&^linux.PermissionsMask != 0 { panic(fmt.Sprintf("Only permission mask must be set: %x", perm&linux.PermissionsMask)) } - i.InodeAttrs.Init(creds, devMajor, devMinor, ino, linux.ModeRegular|perm) + i.InodeAttrs.Init(ctx, creds, devMajor, devMinor, ino, linux.ModeRegular|perm) } // Open implements kernfs.Inode.Open. diff --git a/pkg/sentry/fsimpl/proc/task_net.go b/pkg/sentry/fsimpl/proc/task_net.go index 3425e8698..5a9ee111f 100644 --- a/pkg/sentry/fsimpl/proc/task_net.go +++ b/pkg/sentry/fsimpl/proc/task_net.go @@ -57,33 +57,33 @@ func (fs *filesystem) newTaskNetDir(task *kernel.Task) kernfs.Inode { // TODO(gvisor.dev/issue/1833): Make sure file contents reflect the task // network namespace. contents = map[string]kernfs.Inode{ - "dev": fs.newInode(root, 0444, &netDevData{stack: stack}), - "snmp": fs.newInode(root, 0444, &netSnmpData{stack: stack}), + "dev": fs.newInode(task, root, 0444, &netDevData{stack: stack}), + "snmp": fs.newInode(task, root, 0444, &netSnmpData{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, if the file contains a header the stub is just the header // otherwise it is an empty file. - "arp": fs.newInode(root, 0444, newStaticFile(arp)), - "netlink": fs.newInode(root, 0444, newStaticFile(netlink)), - "netstat": fs.newInode(root, 0444, &netStatData{}), - "packet": fs.newInode(root, 0444, newStaticFile(packet)), - "protocols": fs.newInode(root, 0444, newStaticFile(protocols)), + "arp": fs.newInode(task, root, 0444, newStaticFile(arp)), + "netlink": fs.newInode(task, root, 0444, newStaticFile(netlink)), + "netstat": fs.newInode(task, root, 0444, &netStatData{}), + "packet": fs.newInode(task, root, 0444, newStaticFile(packet)), + "protocols": fs.newInode(task, root, 0444, newStaticFile(protocols)), // Linux sets psched values to: nsec per usec, psched tick in ns, 1000000, // high res timer ticks per sec (ClockGetres returns 1ns resolution). - "psched": fs.newInode(root, 0444, newStaticFile(psched)), - "ptype": fs.newInode(root, 0444, newStaticFile(ptype)), - "route": fs.newInode(root, 0444, &netRouteData{stack: stack}), - "tcp": fs.newInode(root, 0444, &netTCPData{kernel: k}), - "udp": fs.newInode(root, 0444, &netUDPData{kernel: k}), - "unix": fs.newInode(root, 0444, &netUnixData{kernel: k}), + "psched": fs.newInode(task, root, 0444, newStaticFile(psched)), + "ptype": fs.newInode(task, root, 0444, newStaticFile(ptype)), + "route": fs.newInode(task, root, 0444, &netRouteData{stack: stack}), + "tcp": fs.newInode(task, root, 0444, &netTCPData{kernel: k}), + "udp": fs.newInode(task, root, 0444, &netUDPData{kernel: k}), + "unix": fs.newInode(task, root, 0444, &netUnixData{kernel: k}), } if stack.SupportsIPv6() { - contents["if_inet6"] = fs.newInode(root, 0444, &ifinet6{stack: stack}) - contents["ipv6_route"] = fs.newInode(root, 0444, newStaticFile("")) - contents["tcp6"] = fs.newInode(root, 0444, &netTCP6Data{kernel: k}) - contents["udp6"] = fs.newInode(root, 0444, newStaticFile(upd6)) + contents["if_inet6"] = fs.newInode(task, root, 0444, &ifinet6{stack: stack}) + contents["ipv6_route"] = fs.newInode(task, root, 0444, newStaticFile("")) + contents["tcp6"] = fs.newInode(task, root, 0444, &netTCP6Data{kernel: k}) + contents["udp6"] = fs.newInode(task, root, 0444, newStaticFile(upd6)) } } diff --git a/pkg/sentry/fsimpl/proc/tasks.go b/pkg/sentry/fsimpl/proc/tasks.go index 3259c3732..b81ea14bf 100644 --- a/pkg/sentry/fsimpl/proc/tasks.go +++ b/pkg/sentry/fsimpl/proc/tasks.go @@ -62,19 +62,19 @@ type tasksInode struct { var _ kernfs.Inode = (*tasksInode)(nil) -func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { +func (fs *filesystem) newTasksInode(ctx context.Context, k *kernel.Kernel, pidns *kernel.PIDNamespace, cgroupControllers map[string]string) *tasksInode { root := auth.NewRootCredentials(pidns.UserNamespace()) contents := map[string]kernfs.Inode{ - "cpuinfo": fs.newInode(root, 0444, newStaticFileSetStat(cpuInfoData(k))), - "filesystems": fs.newInode(root, 0444, &filesystemsData{}), - "loadavg": fs.newInode(root, 0444, &loadavgData{}), - "sys": fs.newSysDir(root, k), - "meminfo": fs.newInode(root, 0444, &meminfoData{}), - "mounts": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), - "net": kernfs.NewStaticSymlink(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), - "stat": fs.newInode(root, 0444, &statData{}), - "uptime": fs.newInode(root, 0444, &uptimeData{}), - "version": fs.newInode(root, 0444, &versionData{}), + "cpuinfo": fs.newInode(ctx, root, 0444, newStaticFileSetStat(cpuInfoData(k))), + "filesystems": fs.newInode(ctx, root, 0444, &filesystemsData{}), + "loadavg": fs.newInode(ctx, root, 0444, &loadavgData{}), + "sys": fs.newSysDir(ctx, root, k), + "meminfo": fs.newInode(ctx, root, 0444, &meminfoData{}), + "mounts": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/mounts"), + "net": kernfs.NewStaticSymlink(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), "self/net"), + "stat": fs.newInode(ctx, root, 0444, &statData{}), + "uptime": fs.newInode(ctx, root, 0444, &uptimeData{}), + "version": fs.newInode(ctx, root, 0444, &versionData{}), } inode := &tasksInode{ @@ -82,7 +82,7 @@ func (fs *filesystem) newTasksInode(k *kernel.Kernel, pidns *kernel.PIDNamespace fs: fs, cgroupControllers: cgroupControllers, } - inode.InodeAttrs.Init(root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) + inode.InodeAttrs.Init(ctx, root, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0555) inode.EnableLeakCheck() inode.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) @@ -106,9 +106,9 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err // If it failed to parse, check if it's one of the special handled files. switch name { case selfName: - return i.newSelfSymlink(root), nil + return i.newSelfSymlink(ctx, root), nil case threadSelfName: - return i.newThreadSelfSymlink(root), nil + return i.newThreadSelfSymlink(ctx, root), nil } return nil, syserror.ENOENT } @@ -122,7 +122,7 @@ func (i *tasksInode) Lookup(ctx context.Context, name string) (kernfs.Inode, err } // IterDirents implements kernfs.inodeDirectory.IterDirents. -func (i *tasksInode) IterDirents(ctx context.Context, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { +func (i *tasksInode) IterDirents(ctx context.Context, mnt *vfs.Mount, cb vfs.IterDirentsCallback, offset, _ int64) (int64, error) { // fs/proc/internal.h: #define FIRST_PROCESS_ENTRY 256 const FIRST_PROCESS_ENTRY = 256 diff --git a/pkg/sentry/fsimpl/proc/tasks_files.go b/pkg/sentry/fsimpl/proc/tasks_files.go index 07c27cdd9..01b7a6678 100644 --- a/pkg/sentry/fsimpl/proc/tasks_files.go +++ b/pkg/sentry/fsimpl/proc/tasks_files.go @@ -43,9 +43,9 @@ type selfSymlink struct { var _ kernfs.Inode = (*selfSymlink)(nil) -func (i *tasksInode) newSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &selfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } @@ -84,9 +84,9 @@ type threadSelfSymlink struct { var _ kernfs.Inode = (*threadSelfSymlink)(nil) -func (i *tasksInode) newThreadSelfSymlink(creds *auth.Credentials) kernfs.Inode { +func (i *tasksInode) newThreadSelfSymlink(ctx context.Context, creds *auth.Credentials) kernfs.Inode { inode := &threadSelfSymlink{pidns: i.pidns} - inode.Init(creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) + inode.Init(ctx, creds, linux.UNNAMED_MAJOR, i.fs.devMinor, i.fs.NextIno(), linux.ModeSymlink|0777) return inode } diff --git a/pkg/sentry/fsimpl/proc/tasks_sys.go b/pkg/sentry/fsimpl/proc/tasks_sys.go index 95420368d..7c7afdcfa 100644 --- a/pkg/sentry/fsimpl/proc/tasks_sys.go +++ b/pkg/sentry/fsimpl/proc/tasks_sys.go @@ -40,93 +40,93 @@ const ( ) // newSysDir returns the dentry corresponding to /proc/sys directory. -func (fs *filesystem) newSysDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { - return fs.newStaticDir(root, map[string]kernfs.Inode{ - "kernel": fs.newStaticDir(root, map[string]kernfs.Inode{ - "hostname": fs.newInode(root, 0444, &hostnameData{}), - "shmall": fs.newInode(root, 0444, shmData(linux.SHMALL)), - "shmmax": fs.newInode(root, 0444, shmData(linux.SHMMAX)), - "shmmni": fs.newInode(root, 0444, shmData(linux.SHMMNI)), +func (fs *filesystem) newSysDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { + return fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "kernel": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "hostname": fs.newInode(ctx, root, 0444, &hostnameData{}), + "shmall": fs.newInode(ctx, root, 0444, shmData(linux.SHMALL)), + "shmmax": fs.newInode(ctx, root, 0444, shmData(linux.SHMMAX)), + "shmmni": fs.newInode(ctx, root, 0444, shmData(linux.SHMMNI)), }), - "vm": fs.newStaticDir(root, map[string]kernfs.Inode{ - "mmap_min_addr": fs.newInode(root, 0444, &mmapMinAddrData{k: k}), - "overcommit_memory": fs.newInode(root, 0444, newStaticFile("0\n")), + "vm": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "mmap_min_addr": fs.newInode(ctx, root, 0444, &mmapMinAddrData{k: k}), + "overcommit_memory": fs.newInode(ctx, root, 0444, newStaticFile("0\n")), }), - "net": fs.newSysNetDir(root, k), + "net": fs.newSysNetDir(ctx, root, k), }) } // newSysNetDir returns the dentry corresponding to /proc/sys/net directory. -func (fs *filesystem) newSysNetDir(root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { +func (fs *filesystem) newSysNetDir(ctx context.Context, root *auth.Credentials, k *kernel.Kernel) kernfs.Inode { var contents map[string]kernfs.Inode // TODO(gvisor.dev/issue/1833): Support for using the network stack in the // network namespace of the calling process. if stack := k.RootNetworkNamespace().Stack(); stack != nil { contents = map[string]kernfs.Inode{ - "ipv4": fs.newStaticDir(root, map[string]kernfs.Inode{ - "tcp_recovery": fs.newInode(root, 0644, &tcpRecoveryData{stack: stack}), - "tcp_rmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), - "tcp_sack": fs.newInode(root, 0644, &tcpSackData{stack: stack}), - "tcp_wmem": fs.newInode(root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), - "ip_forward": fs.newInode(root, 0444, &ipForwarding{stack: stack}), + "ipv4": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "tcp_recovery": fs.newInode(ctx, root, 0644, &tcpRecoveryData{stack: stack}), + "tcp_rmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpRMem}), + "tcp_sack": fs.newInode(ctx, root, 0644, &tcpSackData{stack: stack}), + "tcp_wmem": fs.newInode(ctx, root, 0644, &tcpMemData{stack: stack, dir: tcpWMem}), + "ip_forward": fs.newInode(ctx, root, 0444, &ipForwarding{stack: stack}), // The following files are simple stubs until they are implemented in // netstack, most of these files are configuration related. We use the // value closest to the actual netstack behavior or any empty file, all // of these files will have mode 0444 (read-only for all users). - "ip_local_port_range": fs.newInode(root, 0444, newStaticFile("16000 65535")), - "ip_local_reserved_ports": fs.newInode(root, 0444, newStaticFile("")), - "ipfrag_time": fs.newInode(root, 0444, newStaticFile("30")), - "ip_nonlocal_bind": fs.newInode(root, 0444, newStaticFile("0")), - "ip_no_pmtu_disc": fs.newInode(root, 0444, newStaticFile("1")), + "ip_local_port_range": fs.newInode(ctx, root, 0444, newStaticFile("16000 65535")), + "ip_local_reserved_ports": fs.newInode(ctx, root, 0444, newStaticFile("")), + "ipfrag_time": fs.newInode(ctx, root, 0444, newStaticFile("30")), + "ip_nonlocal_bind": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "ip_no_pmtu_disc": fs.newInode(ctx, root, 0444, newStaticFile("1")), // tcp_allowed_congestion_control tell the user what they are able to // do as an unprivledged process so we leave it empty. - "tcp_allowed_congestion_control": fs.newInode(root, 0444, newStaticFile("")), - "tcp_available_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), - "tcp_congestion_control": fs.newInode(root, 0444, newStaticFile("reno")), + "tcp_allowed_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_available_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), + "tcp_congestion_control": fs.newInode(ctx, root, 0444, newStaticFile("reno")), // Many of the following stub files are features netstack doesn't // support. The unsupported features return "0" to indicate they are // disabled. - "tcp_base_mss": fs.newInode(root, 0444, newStaticFile("1280")), - "tcp_dsack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_early_retrans": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fack": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_fastopen_key": fs.newInode(root, 0444, newStaticFile("")), - "tcp_invalid_ratelimit": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_intvl": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_probes": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_keepalive_time": fs.newInode(root, 0444, newStaticFile("7200")), - "tcp_mtu_probing": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_no_metrics_save": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_probe_interval": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_probe_threshold": fs.newInode(root, 0444, newStaticFile("0")), - "tcp_retries1": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_retries2": fs.newInode(root, 0444, newStaticFile("15")), - "tcp_rfc1337": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_slow_start_after_idle": fs.newInode(root, 0444, newStaticFile("1")), - "tcp_synack_retries": fs.newInode(root, 0444, newStaticFile("5")), - "tcp_syn_retries": fs.newInode(root, 0444, newStaticFile("3")), - "tcp_timestamps": fs.newInode(root, 0444, newStaticFile("1")), + "tcp_base_mss": fs.newInode(ctx, root, 0444, newStaticFile("1280")), + "tcp_dsack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_early_retrans": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fack": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_fastopen_key": fs.newInode(ctx, root, 0444, newStaticFile("")), + "tcp_invalid_ratelimit": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_intvl": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_probes": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_keepalive_time": fs.newInode(ctx, root, 0444, newStaticFile("7200")), + "tcp_mtu_probing": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_no_metrics_save": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_probe_interval": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_probe_threshold": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "tcp_retries1": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_retries2": fs.newInode(ctx, root, 0444, newStaticFile("15")), + "tcp_rfc1337": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_slow_start_after_idle": fs.newInode(ctx, root, 0444, newStaticFile("1")), + "tcp_synack_retries": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "tcp_syn_retries": fs.newInode(ctx, root, 0444, newStaticFile("3")), + "tcp_timestamps": fs.newInode(ctx, root, 0444, newStaticFile("1")), }), - "core": fs.newStaticDir(root, map[string]kernfs.Inode{ - "default_qdisc": fs.newInode(root, 0444, newStaticFile("pfifo_fast")), - "message_burst": fs.newInode(root, 0444, newStaticFile("10")), - "message_cost": fs.newInode(root, 0444, newStaticFile("5")), - "optmem_max": fs.newInode(root, 0444, newStaticFile("0")), - "rmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "rmem_max": fs.newInode(root, 0444, newStaticFile("212992")), - "somaxconn": fs.newInode(root, 0444, newStaticFile("128")), - "wmem_default": fs.newInode(root, 0444, newStaticFile("212992")), - "wmem_max": fs.newInode(root, 0444, newStaticFile("212992")), + "core": fs.newStaticDir(ctx, root, map[string]kernfs.Inode{ + "default_qdisc": fs.newInode(ctx, root, 0444, newStaticFile("pfifo_fast")), + "message_burst": fs.newInode(ctx, root, 0444, newStaticFile("10")), + "message_cost": fs.newInode(ctx, root, 0444, newStaticFile("5")), + "optmem_max": fs.newInode(ctx, root, 0444, newStaticFile("0")), + "rmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "rmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "somaxconn": fs.newInode(ctx, root, 0444, newStaticFile("128")), + "wmem_default": fs.newInode(ctx, root, 0444, newStaticFile("212992")), + "wmem_max": fs.newInode(ctx, root, 0444, newStaticFile("212992")), }), } } - return fs.newStaticDir(root, contents) + return fs.newStaticDir(ctx, root, contents) } // mmapMinAddrData implements vfs.DynamicBytesSource for diff --git a/pkg/sentry/fsimpl/proc/tasks_test.go b/pkg/sentry/fsimpl/proc/tasks_test.go index 2582ababd..7ee6227a9 100644 --- a/pkg/sentry/fsimpl/proc/tasks_test.go +++ b/pkg/sentry/fsimpl/proc/tasks_test.go @@ -77,6 +77,7 @@ var ( "gid_map": linux.DT_REG, "io": linux.DT_REG, "maps": linux.DT_REG, + "mem": linux.DT_REG, "mountinfo": linux.DT_REG, "mounts": linux.DT_REG, "net": linux.DT_DIR, diff --git a/pkg/sentry/fsimpl/sockfs/sockfs.go b/pkg/sentry/fsimpl/sockfs/sockfs.go index cf91ea36c..fda1fa942 100644 --- a/pkg/sentry/fsimpl/sockfs/sockfs.go +++ b/pkg/sentry/fsimpl/sockfs/sockfs.go @@ -108,13 +108,13 @@ func (i *inode) StatFS(ctx context.Context, fs *vfs.Filesystem) (linux.Statfs, e // NewDentry constructs and returns a sockfs dentry. // // Preconditions: mnt.Filesystem() must have been returned by NewFilesystem(). -func NewDentry(creds *auth.Credentials, mnt *vfs.Mount) *vfs.Dentry { +func NewDentry(ctx context.Context, mnt *vfs.Mount) *vfs.Dentry { fs := mnt.Filesystem().Impl().(*filesystem) // File mode matches net/socket.c:sock_alloc. filemode := linux.FileMode(linux.S_IFSOCK | 0600) i := &inode{} - i.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) + i.InodeAttrs.Init(ctx, auth.CredentialsFromContext(ctx), linux.UNNAMED_MAJOR, fs.devMinor, fs.Filesystem.NextIno(), filemode) d := &kernfs.Dentry{} d.Init(&fs.Filesystem, i) diff --git a/pkg/sentry/fsimpl/sys/BUILD b/pkg/sentry/fsimpl/sys/BUILD index 906cd52cb..09043b572 100644 --- a/pkg/sentry/fsimpl/sys/BUILD +++ b/pkg/sentry/fsimpl/sys/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "dir_refs.go", package = "sys", prefix = "dir", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "dir", }, @@ -28,6 +28,7 @@ go_library( "//pkg/coverage", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fsimpl/kernfs", "//pkg/sentry/kernel", diff --git a/pkg/sentry/fsimpl/sys/kcov.go b/pkg/sentry/fsimpl/sys/kcov.go index 31a361029..b13f141a8 100644 --- a/pkg/sentry/fsimpl/sys/kcov.go +++ b/pkg/sentry/fsimpl/sys/kcov.go @@ -29,7 +29,7 @@ import ( func (fs *filesystem) newKcovFile(ctx context.Context, creds *auth.Credentials) kernfs.Inode { k := &kcovInode{} - k.InodeAttrs.Init(creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) + k.InodeAttrs.Init(ctx, creds, 0, 0, fs.NextIno(), linux.S_IFREG|0600) return k } diff --git a/pkg/sentry/fsimpl/sys/sys.go b/pkg/sentry/fsimpl/sys/sys.go index 1ad679830..506a2a0f0 100644 --- a/pkg/sentry/fsimpl/sys/sys.go +++ b/pkg/sentry/fsimpl/sys/sys.go @@ -18,6 +18,7 @@ package sys import ( "bytes" "fmt" + "strconv" "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" @@ -29,9 +30,12 @@ import ( "gvisor.dev/gvisor/pkg/syserror" ) -// Name is the default filesystem name. -const Name = "sysfs" -const defaultSysDirMode = linux.FileMode(0755) +const ( + // Name is the default filesystem name. + Name = "sysfs" + defaultSysDirMode = linux.FileMode(0755) + defaultMaxCachedDentries = uint64(1000) +) // FilesystemType implements vfs.FilesystemType. // @@ -62,31 +66,43 @@ func (fsType FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt return nil, nil, err } + mopts := vfs.GenericParseMountOptions(opts.Data) + maxCachedDentries := defaultMaxCachedDentries + if str, ok := mopts["dentry_cache_limit"]; ok { + delete(mopts, "dentry_cache_limit") + maxCachedDentries, err = strconv.ParseUint(str, 10, 64) + if err != nil { + ctx.Warningf("sys.FilesystemType.GetFilesystem: invalid dentry cache limit: dentry_cache_limit=%s", str) + return nil, nil, syserror.EINVAL + } + } + fs := &filesystem{ devMinor: devMinor, } + fs.MaxCachedDentries = maxCachedDentries fs.VFSFilesystem().Init(vfsObj, &fsType, fs) - root := fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "block": fs.newDir(creds, defaultSysDirMode, nil), - "bus": fs.newDir(creds, defaultSysDirMode, nil), - "class": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "power_supply": fs.newDir(creds, defaultSysDirMode, nil), + root := fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "block": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "bus": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "class": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "power_supply": fs.newDir(ctx, creds, defaultSysDirMode, nil), }), - "dev": fs.newDir(creds, defaultSysDirMode, nil), - "devices": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ - "system": fs.newDir(creds, defaultSysDirMode, map[string]kernfs.Inode{ + "dev": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "devices": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ + "system": fs.newDir(ctx, creds, defaultSysDirMode, map[string]kernfs.Inode{ "cpu": cpuDir(ctx, fs, creds), }), }), - "firmware": fs.newDir(creds, defaultSysDirMode, nil), - "fs": fs.newDir(creds, defaultSysDirMode, nil), + "firmware": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "fs": fs.newDir(ctx, creds, defaultSysDirMode, nil), "kernel": kernelDir(ctx, fs, creds), - "module": fs.newDir(creds, defaultSysDirMode, nil), - "power": fs.newDir(creds, defaultSysDirMode, nil), + "module": fs.newDir(ctx, creds, defaultSysDirMode, nil), + "power": fs.newDir(ctx, creds, defaultSysDirMode, nil), }) var rootD kernfs.Dentry - rootD.Init(&fs.Filesystem, root) + rootD.InitRoot(&fs.Filesystem, root) return fs.VFSFilesystem(), rootD.VFSDentry(), nil } @@ -94,14 +110,14 @@ func cpuDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs k := kernel.KernelFromContext(ctx) maxCPUCores := k.ApplicationCores() children := map[string]kernfs.Inode{ - "online": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "possible": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), - "present": fs.newCPUFile(creds, maxCPUCores, linux.FileMode(0444)), + "online": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "possible": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), + "present": fs.newCPUFile(ctx, creds, maxCPUCores, linux.FileMode(0444)), } for i := uint(0); i < maxCPUCores; i++ { - children[fmt.Sprintf("cpu%d", i)] = fs.newDir(creds, linux.FileMode(0555), nil) + children[fmt.Sprintf("cpu%d", i)] = fs.newDir(ctx, creds, linux.FileMode(0555), nil) } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) kernfs.Inode { @@ -111,12 +127,12 @@ func kernelDir(ctx context.Context, fs *filesystem, creds *auth.Credentials) ker var children map[string]kernfs.Inode if coverage.KcovAvailable() { children = map[string]kernfs.Inode{ - "debug": fs.newDir(creds, linux.FileMode(0700), map[string]kernfs.Inode{ + "debug": fs.newDir(ctx, creds, linux.FileMode(0700), map[string]kernfs.Inode{ "kcov": fs.newKcovFile(ctx, creds), }), } } - return fs.newDir(creds, defaultSysDirMode, children) + return fs.newDir(ctx, creds, defaultSysDirMode, children) } // Release implements vfs.FilesystemImpl.Release. @@ -140,9 +156,9 @@ type dir struct { locks vfs.FileLocks } -func (fs *filesystem) newDir(creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { +func (fs *filesystem) newDir(ctx context.Context, creds *auth.Credentials, mode linux.FileMode, contents map[string]kernfs.Inode) kernfs.Inode { d := &dir{} - d.InodeAttrs.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) + d.InodeAttrs.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), linux.ModeDirectory|0755) d.OrderedChildren.Init(kernfs.OrderedChildrenOptions{}) d.EnableLeakCheck() d.IncLinks(d.OrderedChildren.Populate(contents)) @@ -191,9 +207,9 @@ func (c *cpuFile) Generate(ctx context.Context, buf *bytes.Buffer) error { return nil } -func (fs *filesystem) newCPUFile(creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { +func (fs *filesystem) newCPUFile(ctx context.Context, creds *auth.Credentials, maxCores uint, mode linux.FileMode) kernfs.Inode { c := &cpuFile{maxCores: maxCores} - c.DynamicBytesFile.Init(creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) + c.DynamicBytesFile.Init(ctx, creds, linux.UNNAMED_MAJOR, fs.devMinor, fs.NextIno(), c, mode) return c } diff --git a/pkg/sentry/fsimpl/tmpfs/BUILD b/pkg/sentry/fsimpl/tmpfs/BUILD index 5cd428d64..fe520b6fd 100644 --- a/pkg/sentry/fsimpl/tmpfs/BUILD +++ b/pkg/sentry/fsimpl/tmpfs/BUILD @@ -31,7 +31,7 @@ go_template_instance( out = "inode_refs.go", package = "tmpfs", prefix = "inode", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "inode", }, @@ -48,6 +48,7 @@ go_library( "inode_refs.go", "named_pipe.go", "regular_file.go", + "save_restore.go", "socket_file.go", "symlink.go", "tmpfs.go", @@ -60,6 +61,7 @@ go_library( "//pkg/fspath", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/fsimpl/tmpfs/named_pipe.go b/pkg/sentry/fsimpl/tmpfs/named_pipe.go index d772db9e9..57e7b57b0 100644 --- a/pkg/sentry/fsimpl/tmpfs/named_pipe.go +++ b/pkg/sentry/fsimpl/tmpfs/named_pipe.go @@ -18,7 +18,6 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/sentry/kernel/pipe" - "gvisor.dev/gvisor/pkg/usermem" ) // +stateify savable @@ -32,7 +31,7 @@ type namedPipe struct { // * fs.mu must be locked. // * rp.Mount().CheckBeginWrite() has been called successfully. func (fs *filesystem) newNamedPipe(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { - file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize, usermem.PageSize)} + file := &namedPipe{pipe: pipe.NewVFSPipe(true /* isNamed */, pipe.DefaultPipeSize)} file.inode.init(file, fs, kuid, kgid, linux.S_IFIFO|mode) file.inode.nlink = 1 // Only the parent has a link. return &file.inode diff --git a/pkg/sentry/fsimpl/tmpfs/regular_file.go b/pkg/sentry/fsimpl/tmpfs/regular_file.go index ce4e3eda7..98680fde9 100644 --- a/pkg/sentry/fsimpl/tmpfs/regular_file.go +++ b/pkg/sentry/fsimpl/tmpfs/regular_file.go @@ -42,7 +42,7 @@ type regularFile struct { inode inode // memFile is a platform.File used to allocate pages to this regularFile. - memFile *pgalloc.MemoryFile + memFile *pgalloc.MemoryFile `state:"nosave"` // memoryUsageKind is the memory accounting category under which pages backing // this regularFile's contents are accounted. @@ -92,7 +92,7 @@ type regularFile struct { func (fs *filesystem) newRegularFile(kuid auth.KUID, kgid auth.KGID, mode linux.FileMode) *inode { file := ®ularFile{ - memFile: fs.memFile, + memFile: fs.mfp.MemoryFile(), memoryUsageKind: usage.Tmpfs, seals: linux.F_SEAL_SEAL, } diff --git a/pkg/sentry/fsimpl/tmpfs/save_restore.go b/pkg/sentry/fsimpl/tmpfs/save_restore.go new file mode 100644 index 000000000..b27f75cc2 --- /dev/null +++ b/pkg/sentry/fsimpl/tmpfs/save_restore.go @@ -0,0 +1,20 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package tmpfs + +// afterLoad is called by stateify. +func (rf *regularFile) afterLoad() { + rf.memFile = rf.inode.fs.mfp.MemoryFile() +} diff --git a/pkg/sentry/fsimpl/tmpfs/tmpfs.go b/pkg/sentry/fsimpl/tmpfs/tmpfs.go index e2a0aac69..4ce859d57 100644 --- a/pkg/sentry/fsimpl/tmpfs/tmpfs.go +++ b/pkg/sentry/fsimpl/tmpfs/tmpfs.go @@ -61,8 +61,9 @@ type FilesystemType struct{} type filesystem struct { vfsfs vfs.Filesystem - // memFile is used to allocate pages to for regular files. - memFile *pgalloc.MemoryFile + // mfp is used to allocate memory that stores regular file contents. mfp is + // immutable. + mfp pgalloc.MemoryFileProvider // clock is a realtime clock used to set timestamps in file operations. clock time.Clock @@ -106,8 +107,8 @@ type FilesystemOpts struct { // GetFilesystem implements vfs.FilesystemType.GetFilesystem. func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.VirtualFilesystem, creds *auth.Credentials, _ string, opts vfs.GetFilesystemOptions) (*vfs.Filesystem, *vfs.Dentry, error) { - memFileProvider := pgalloc.MemoryFileProviderFromContext(ctx) - if memFileProvider == nil { + mfp := pgalloc.MemoryFileProviderFromContext(ctx) + if mfp == nil { panic("MemoryFileProviderFromContext returned nil") } @@ -181,7 +182,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt } clock := time.RealtimeClockFromContext(ctx) fs := filesystem{ - memFile: memFileProvider.MemoryFile(), + mfp: mfp, clock: clock, devMinor: devMinor, } diff --git a/pkg/sentry/fsimpl/verity/BUILD b/pkg/sentry/fsimpl/verity/BUILD index 0ca750281..e265be0ee 100644 --- a/pkg/sentry/fsimpl/verity/BUILD +++ b/pkg/sentry/fsimpl/verity/BUILD @@ -6,6 +6,7 @@ go_library( name = "verity", srcs = [ "filesystem.go", + "save_restore.go", "verity.go", ], visibility = ["//pkg/sentry:internal"], @@ -15,6 +16,7 @@ go_library( "//pkg/fspath", "//pkg/marshal/primitive", "//pkg/merkletree", + "//pkg/refsvfs2", "//pkg/sentry/arch", "//pkg/sentry/fs/lock", "//pkg/sentry/kernel", @@ -38,10 +40,12 @@ go_test( "//pkg/context", "//pkg/fspath", "//pkg/sentry/arch", + "//pkg/sentry/fsimpl/testutil", "//pkg/sentry/fsimpl/tmpfs", + "//pkg/sentry/kernel", "//pkg/sentry/kernel/auth", - "//pkg/sentry/kernel/contexttest", "//pkg/sentry/vfs", + "//pkg/syserror", "//pkg/usermem", ], ) diff --git a/pkg/sentry/fsimpl/verity/filesystem.go b/pkg/sentry/fsimpl/verity/filesystem.go index 03da505e1..4e8d63d51 100644 --- a/pkg/sentry/fsimpl/verity/filesystem.go +++ b/pkg/sentry/fsimpl/verity/filesystem.go @@ -192,7 +192,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleOffsetInParentXattr, childPath, err)) } if err != nil { return nil, err @@ -201,7 +201,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. offset, err := strconv.Atoi(off) if err != nil { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleOffsetInParentXattr, childPath, err)) } // Open parent Merkle tree file to read and verify child's hash. @@ -215,7 +215,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // The parent Merkle tree file should have been created. If it's // missing, it indicates an unexpected modification to the file system. if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to open parent Merkle file for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -233,7 +233,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // contains the expected xattrs. If the file or the xattr does not // exist, it indicates unexpected modifications to the file system. if err == syserror.ENOENT || err == syserror.ENODATA { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return nil, err @@ -243,7 +243,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // unexpected modifications to the file system. parentSize, err := strconv.Atoi(dataSize) if err != nil { - return nil, alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -256,7 +256,7 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de Start: parent.lowerVD, }, &vfs.StatOptions{}) if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to get parent stat for %s: %v", childPath, err)) } if err != nil { return nil, err @@ -267,20 +267,22 @@ func (fs *filesystem) verifyChild(ctx context.Context, parent *dentry, child *de // Verify returns with success. var buf bytes.Buffer if _, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: &buf, - File: &fdReader, - Tree: &fdReader, - Size: int64(parentSize), - Name: parent.name, - Mode: uint32(parentStat.Mode), - UID: parentStat.UID, - GID: parentStat.GID, + Out: &buf, + File: &fdReader, + Tree: &fdReader, + Size: int64(parentSize), + Name: parent.name, + Mode: uint32(parentStat.Mode), + UID: parentStat.UID, + GID: parentStat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fs.alg.toLinuxHashAlg(), ReadOffset: int64(offset), - ReadSize: int64(merkletree.DigestSize()), + ReadSize: int64(merkletree.DigestSize(fs.alg.toLinuxHashAlg())), Expected: parent.hash, DataAndTreeInSameFile: true, }); err != nil && err != io.EOF { - return nil, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification for %s failed: %v", childPath, err)) + return nil, alertIntegrityViolation(fmt.Sprintf("Verification for %s failed: %v", childPath, err)) } // Cache child hash when it's verified the first time. @@ -312,7 +314,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat Flags: linux.O_RDONLY, }) if err == syserror.ENOENT { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to open merkle file for %s: %v", childPath, err)) } if err != nil { return err @@ -324,7 +326,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat }) if err == syserror.ENODATA { - return alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s for merkle file of %s: %v", merkleSizeXattr, childPath, err)) } if err != nil { return err @@ -332,7 +334,7 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat size, err := strconv.Atoi(merkleSize) if err != nil { - return alertIntegrityViolation(syserror.EINVAL, fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s for %s to int: %v", merkleSizeXattr, childPath, err)) } fdReader := vfs.FileReadWriteSeeker{ @@ -342,14 +344,16 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat var buf bytes.Buffer params := &merkletree.VerifyParams{ - Out: &buf, - Tree: &fdReader, - Size: int64(size), - Name: d.name, - Mode: uint32(stat.Mode), - UID: stat.UID, - GID: stat.GID, - ReadOffset: 0, + Out: &buf, + Tree: &fdReader, + Size: int64(size), + Name: d.name, + Mode: uint32(stat.Mode), + UID: stat.UID, + GID: stat.GID, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fs.alg.toLinuxHashAlg(), + ReadOffset: 0, // Set read size to 0 so only the metadata is verified. ReadSize: 0, Expected: d.hash, @@ -360,17 +364,57 @@ func (fs *filesystem) verifyStat(ctx context.Context, d *dentry, stat linux.Stat } if _, err := merkletree.Verify(params); err != nil && err != io.EOF { - return alertIntegrityViolation(err, fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) + return alertIntegrityViolation(fmt.Sprintf("Verification stat for %s failed: %v", childPath, err)) } d.mode = uint32(stat.Mode) d.uid = stat.UID d.gid = stat.GID + d.size = uint32(size) return nil } // Preconditions: fs.renameMu must be locked. d.dirMu must be locked. func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name string, ds **[]*dentry) (*dentry, error) { if child, ok := parent.children[name]; ok { + // If verity is enabled on child, we should check again whether + // the file and the corresponding Merkle tree are as expected, + // in order to catch deletion/renaming after the last time it's + // accessed. + if child.verityEnabled() { + vfsObj := fs.vfsfs.VirtualFilesystem() + // Get the path to the child dentry. This is only used + // to provide path information in failure case. + path, err := vfsObj.PathnameWithDeleted(ctx, child.fs.rootDentry.lowerVD, child.lowerVD) + if err != nil { + return nil, err + } + + childVD, err := parent.getLowerAt(ctx, vfsObj, name) + if err == syserror.ENOENT { + // The file was previously accessed. If the + // file does not exist now, it indicates an + // unexpected modification to the file system. + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", path)) + } + if err != nil { + return nil, err + } + defer childVD.DecRef(ctx) + + childMerkleVD, err := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) + // The Merkle tree file was previous accessed. If it + // does not exist now, it indicates an unexpected + // modification to the file system. + if err == syserror.ENOENT { + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", path)) + } + if err != nil { + return nil, err + } + + defer childMerkleVD.DecRef(ctx) + } + // If enabling verification on files/directories is not allowed // during runtime, all cached children are already verified. If // runtime enable is allowed and the parent directory is @@ -418,13 +462,7 @@ func (fs *filesystem) getChildLocked(ctx context.Context, parent *dentry, name s func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, name string) (*dentry, error) { vfsObj := fs.vfsfs.VirtualFilesystem() - childFilename := fspath.Parse(name) - childVD, childErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: childFilename, - }, &vfs.GetDentryOptions{}) - + childVD, childErr := parent.getLowerAt(ctx, vfsObj, name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childErr != nil && childErr != syserror.ENOENT { @@ -437,13 +475,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, defer childVD.DecRef(ctx) } - childMerkleFilename := merklePrefix + name - childMerkleVD, childMerkleErr := vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) - + childMerkleVD, childMerkleErr := parent.getLowerAt(ctx, vfsObj, merklePrefix+name) // We will handle ENOENT separately, as it may indicate unexpected // modifications to the file system, and may cause a sentry panic. if childMerkleErr != nil && childMerkleErr != syserror.ENOENT { @@ -472,7 +504,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // corresponding Merkle tree is found. This indicates an // unexpected modification to the file system that // removed/renamed the child. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Target file %s is expected but missing", parentPath+"/"+name)) } else if childErr == nil && childMerkleErr == syserror.ENOENT { // If in allowRuntimeEnable mode, and the Merkle tree file is // not created yet, we create an empty Merkle tree file, so that @@ -488,7 +520,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, childMerkleFD, err := vfsObj.OpenAt(ctx, fs.creds, &vfs.PathOperation{ Root: parent.lowerVD, Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), + Path: fspath.Parse(merklePrefix + name), }, &vfs.OpenOptions{ Flags: linux.O_RDWR | linux.O_CREAT, Mode: 0644, @@ -497,11 +529,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, return nil, err } childMerkleFD.DecRef(ctx) - childMerkleVD, err = vfsObj.GetDentryAt(ctx, fs.creds, &vfs.PathOperation{ - Root: parent.lowerVD, - Start: parent.lowerVD, - Path: fspath.Parse(childMerkleFilename), - }, &vfs.GetDentryOptions{}) + childMerkleVD, err = parent.getLowerAt(ctx, vfsObj, merklePrefix+name) if err != nil { return nil, err } @@ -509,7 +537,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // If runtime enable is not allowed. This indicates an // unexpected modification to the file system that // removed/renamed the Merkle tree file. - return nil, alertIntegrityViolation(childMerkleErr, fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Expected Merkle file for target %s but none found", parentPath+"/"+name)) } } else if childErr == syserror.ENOENT && childMerkleErr == syserror.ENOENT { // Both the child and the corresponding Merkle tree are missing. @@ -518,7 +546,7 @@ func (fs *filesystem) lookupAndVerifyLocked(ctx context.Context, parent *dentry, // TODO(b/167752508): Investigate possible ways to differentiate // cases that both files are deleted from cases that they never // exist in the file system. - return nil, alertIntegrityViolation(childErr, fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) + return nil, alertIntegrityViolation(fmt.Sprintf("Failed to find file %s", parentPath+"/"+name)) } mask := uint32(linux.STATX_TYPE | linux.STATX_MODE | linux.STATX_UID | linux.STATX_GID) @@ -762,7 +790,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // missing, it indicates an unexpected modification to the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("File %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("File %s expected but not found", path)) } return nil, err } @@ -785,7 +813,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf // the file system. if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -810,7 +838,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf }) if err != nil { if err == syserror.ENOENT { - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", path)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", path)) } return nil, err } @@ -828,7 +856,7 @@ func (d *dentry) openLocked(ctx context.Context, rp *vfs.ResolvingPath, opts *vf if err != nil { if err == syserror.ENOENT { parentPath, _ := d.fs.vfsfs.VirtualFilesystem().PathnameWithDeleted(ctx, d.fs.rootDentry.lowerVD, d.parent.lowerVD) - return nil, alertIntegrityViolation(err, fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) + return nil, alertIntegrityViolation(fmt.Sprintf("Merkle file for %s expected but not found", parentPath)) } return nil, err } diff --git a/pkg/sentry/fsimpl/verity/save_restore.go b/pkg/sentry/fsimpl/verity/save_restore.go new file mode 100644 index 000000000..46b064342 --- /dev/null +++ b/pkg/sentry/fsimpl/verity/save_restore.go @@ -0,0 +1,27 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package verity + +import ( + "sync/atomic" + + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +func (d *dentry) afterLoad() { + if atomic.LoadInt64(&d.refs) != -1 { + refsvfs2.Register(d) + } +} diff --git a/pkg/sentry/fsimpl/verity/verity.go b/pkg/sentry/fsimpl/verity/verity.go index 8dc9e26bc..d24c839bb 100644 --- a/pkg/sentry/fsimpl/verity/verity.go +++ b/pkg/sentry/fsimpl/verity/verity.go @@ -23,6 +23,7 @@ package verity import ( "fmt" + "math" "strconv" "sync/atomic" @@ -31,6 +32,7 @@ import ( "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/marshal/primitive" "gvisor.dev/gvisor/pkg/merkletree" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/arch" fslock "gvisor.dev/gvisor/pkg/sentry/fs/lock" "gvisor.dev/gvisor/pkg/sentry/kernel" @@ -41,32 +43,62 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) -// Name is the default filesystem name. -const Name = "verity" +const ( + // Name is the default filesystem name. + Name = "verity" -// merklePrefix is the prefix of the Merkle tree files. For example, the Merkle -// tree file for "/foo" is "/.merkle.verity.foo". -const merklePrefix = ".merkle.verity." + // merklePrefix is the prefix of the Merkle tree files. For example, the Merkle + // tree file for "/foo" is "/.merkle.verity.foo". + merklePrefix = ".merkle.verity." -// merkleoffsetInParentXattr is the extended attribute name specifying the -// offset of child hash in its parent's Merkle tree. -const merkleOffsetInParentXattr = "user.merkle.offset" + // merkleOffsetInParentXattr is the extended attribute name specifying the + // offset of the child hash in its parent's Merkle tree. + merkleOffsetInParentXattr = "user.merkle.offset" -// merkleSizeXattr is the extended attribute name specifying the size of data -// hashed by the corresponding Merkle tree. For a file, it's the size of the -// whole file. For a directory, it's the size of all its children's hashes. -const merkleSizeXattr = "user.merkle.size" + // merkleSizeXattr is the extended attribute name specifying the size of data + // hashed by the corresponding Merkle tree. For a regular file, this is the + // file size. For a directory, this is the size of all its children's hashes. + merkleSizeXattr = "user.merkle.size" -// sizeOfStringInt32 is the size for a 32 bit integer stored as string in -// extended attributes. The maximum value of a 32 bit integer is 10 digits. -const sizeOfStringInt32 = 10 + // sizeOfStringInt32 is the size for a 32 bit integer stored as string in + // extended attributes. The maximum value of a 32 bit integer has 10 digits. + sizeOfStringInt32 = 10 +) -// noCrashOnVerificationFailure indicates whether the sandbox should panic -// whenever verification fails. If true, an error is returned instead of -// panicking. This should only be set for tests. -// TOOD(b/165661693): Decide whether to panic or return error based on this -// flag. -var noCrashOnVerificationFailure bool +var ( + // noCrashOnVerificationFailure indicates whether the sandbox should panic + // whenever verification fails. If true, an error is returned instead of + // panicking. This should only be set for tests. + // + // TODO(b/165661693): Decide whether to panic or return error based on this + // flag. + noCrashOnVerificationFailure bool + + // verityMu synchronizes concurrent operations that enable verity and perform + // verification checks. + verityMu sync.RWMutex +) + +// HashAlgorithm is a type specifying the algorithm used to hash the file +// content. +type HashAlgorithm int + +// Currently supported hashing algorithms include SHA256 and SHA512. +const ( + SHA256 HashAlgorithm = iota + SHA512 +) + +func (alg HashAlgorithm) toLinuxHashAlg() int { + switch alg { + case SHA256: + return linux.FS_VERITY_HASH_ALG_SHA256 + case SHA512: + return linux.FS_VERITY_HASH_ALG_SHA512 + default: + return 0 + } +} // FilesystemType implements vfs.FilesystemType. // @@ -97,6 +129,10 @@ type filesystem struct { // stores the root hash of the whole file system in bytes. rootDentry *dentry + // alg is the algorithms used to hash the files in the verity file + // system. + alg HashAlgorithm + // renameMu synchronizes renaming with non-renaming operations in order // to ensure consistent lock ordering between dentry.dirMu in different // dentries. @@ -125,6 +161,10 @@ type InternalFilesystemOptions struct { // LowerName is the name of the filesystem wrapped by verity fs. LowerName string + // Alg is the algorithms used to hash the files in the verity file + // system. + Alg HashAlgorithm + // RootHash is the root hash of the overall verity file system. RootHash []byte @@ -153,10 +193,10 @@ func (FilesystemType) Release(ctx context.Context) {} // alertIntegrityViolation alerts a violation of integrity, which usually means // unexpected modification to the file system is detected. In -// noCrashOnVerificationFailure mode, it returns an error, otherwise it panic. -func alertIntegrityViolation(err error, msg string) error { +// noCrashOnVerificationFailure mode, it returns EIO, otherwise it panic. +func alertIntegrityViolation(msg string) error { if noCrashOnVerificationFailure { - return err + return syserror.EIO } panic(msg) } @@ -183,6 +223,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt fs := &filesystem{ creds: creds.Fork(), + alg: iopts.Alg, lowerMount: mnt, allowRuntimeEnable: iopts.AllowRuntimeEnable, } @@ -236,7 +277,7 @@ func (fstype FilesystemType) GetFilesystem(ctx context.Context, vfsObj *vfs.Virt // the root Merkle file, or it's never generated. fs.vfsfs.DecRef(ctx) d.DecRef(ctx) - return nil, nil, alertIntegrityViolation(err, "Failed to find root Merkle file") + return nil, nil, alertIntegrityViolation("Failed to find root Merkle file") } d.lowerMerkleVD = lowerMerkleVD @@ -289,11 +330,12 @@ type dentry struct { // fs is the owning filesystem. fs is immutable. fs *filesystem - // mode, uid and gid are the file mode, owner, and group of the file in - // the underlying file system. + // mode, uid, gid and size are the file mode, owner, group, and size of + // the file in the underlying file system. mode uint32 uid uint32 gid uint32 + size uint32 // parent is the dentry corresponding to this dentry's parent directory. // name is this dentry's name in parent. If this dentry is a filesystem @@ -331,22 +373,25 @@ func (fs *filesystem) newDentry() *dentry { fs: fs, } d.vfsd.Init(d) + refsvfs2.Register(d) return d } // IncRef implements vfs.DentryImpl.IncRef. func (d *dentry) IncRef() { - atomic.AddInt64(&d.refs, 1) + r := atomic.AddInt64(&d.refs, 1) + refsvfs2.LogIncRef(d, r) } // TryIncRef implements vfs.DentryImpl.TryIncRef. func (d *dentry) TryIncRef() bool { for { - refs := atomic.LoadInt64(&d.refs) - if refs <= 0 { + r := atomic.LoadInt64(&d.refs) + if r <= 0 { return false } - if atomic.CompareAndSwapInt64(&d.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&d.refs, r, r+1) { + refsvfs2.LogTryIncRef(d, r+1) return true } } @@ -354,15 +399,27 @@ func (d *dentry) TryIncRef() bool { // DecRef implements vfs.DentryImpl.DecRef. func (d *dentry) DecRef(ctx context.Context) { - if refs := atomic.AddInt64(&d.refs, -1); refs == 0 { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { d.fs.renameMu.Lock() d.checkDropLocked(ctx) d.fs.renameMu.Unlock() - } else if refs < 0 { + } else if r < 0 { panic("verity.dentry.DecRef() called without holding a reference") } } +func (d *dentry) decRefLocked(ctx context.Context) { + r := atomic.AddInt64(&d.refs, -1) + refsvfs2.LogDecRef(d, r) + if r == 0 { + d.checkDropLocked(ctx) + } else if r < 0 { + panic("verity.dentry.decRefLocked() called without holding a reference") + } +} + // checkDropLocked should be called after d's reference count becomes 0 or it // becomes deleted. func (d *dentry) checkDropLocked(ctx context.Context) { @@ -393,23 +450,36 @@ func (d *dentry) destroyLocked(ctx context.Context) { if d.lowerVD.Ok() { d.lowerVD.DecRef(ctx) } - if d.lowerMerkleVD.Ok() { d.lowerMerkleVD.DecRef(ctx) } - if d.parent != nil { d.parent.dirMu.Lock() if !d.vfsd.IsDead() { delete(d.parent.children, d.name) } d.parent.dirMu.Unlock() - if refs := atomic.AddInt64(&d.parent.refs, -1); refs == 0 { - d.parent.checkDropLocked(ctx) - } else if refs < 0 { - panic("verity.dentry.DecRef() called without holding a reference") - } + d.parent.decRefLocked(ctx) } + refsvfs2.Unregister(d) +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (d *dentry) RefType() string { + return "verity.dentry" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (d *dentry) LeakMessage() string { + return fmt.Sprintf("[verity.dentry %p] reference count of %d instead of -1", d, atomic.LoadInt64(&d.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (d *dentry) LogRefs() bool { + return false } // InotifyWithParent implements vfs.DentryImpl.InotifyWithParent. @@ -448,6 +518,16 @@ func (d *dentry) verityEnabled() bool { return !d.fs.allowRuntimeEnable || len(d.hash) != 0 } +// getLowerAt returns the dentry in the underlying file system, which is +// represented by filename relative to d. +func (d *dentry) getLowerAt(ctx context.Context, vfsObj *vfs.VirtualFilesystem, filename string) (vfs.VirtualDentry, error) { + return vfsObj.GetDentryAt(ctx, d.fs.creds, &vfs.PathOperation{ + Root: d.lowerVD, + Start: d.lowerVD, + Path: fspath.Parse(filename), + }, &vfs.GetDentryOptions{}) +} + func (d *dentry) readlink(ctx context.Context) (string, error) { return d.fs.vfsfs.VirtualFilesystem().ReadlinkAt(ctx, d.fs.creds, &vfs.PathOperation{ Root: d.lowerVD, @@ -489,6 +569,10 @@ type fileDescription struct { // directory that contains the current file/directory. This is only used // if allowRuntimeEnable is set to true. parentMerkleWriter *vfs.FileDescription + + // off is the file offset. off is protected by mu. + mu sync.Mutex `state:"nosave"` + off int64 } // Release implements vfs.FileDescriptionImpl.Release. @@ -524,6 +608,32 @@ func (fd *fileDescription) SetStat(ctx context.Context, opts vfs.SetStatOptions) return syserror.EPERM } +// Seek implements vfs.FileDescriptionImpl.Seek. +func (fd *fileDescription) Seek(ctx context.Context, offset int64, whence int32) (int64, error) { + fd.mu.Lock() + defer fd.mu.Unlock() + n := int64(0) + switch whence { + case linux.SEEK_SET: + // use offset as specified + case linux.SEEK_CUR: + n = fd.off + case linux.SEEK_END: + n = int64(fd.d.size) + default: + return 0, syserror.EINVAL + } + if offset > math.MaxInt64-n { + return 0, syserror.EINVAL + } + offset += n + if offset < 0 { + return 0, syserror.EINVAL + } + fd.off = offset + return offset, nil +} + // generateMerkle generates a Merkle tree file for fd. If fd points to a file // /foo/bar, a Merkle tree file /foo/.merkle.verity.bar is generated. The hash // of the generated Merkle tree and the data size is returned. If fd points to @@ -546,6 +656,8 @@ func (fd *fileDescription) generateMerkle(ctx context.Context) ([]byte, uint64, params := &merkletree.GenerateParams{ TreeReader: &merkleReader, TreeWriter: &merkleWriter, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), } switch atomic.LoadUint32(&fd.d.mode) & linux.S_IFMT { @@ -611,7 +723,7 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // or directory other than the root, the parent Merkle tree file should // have also been initialized. if fd.lowerFD == nil || fd.merkleReader == nil || fd.merkleWriter == nil || (fd.parentMerkleWriter == nil && fd.d != fd.d.fs.rootDentry) { - return 0, alertIntegrityViolation(syserror.EIO, "Unexpected verity fd: missing expected underlying fds") + return 0, alertIntegrityViolation("Unexpected verity fd: missing expected underlying fds") } hash, dataSize, err := fd.generateMerkle(ctx) @@ -657,6 +769,9 @@ func (fd *fileDescription) enableVerity(ctx context.Context, uio usermem.IO) (ui // measureVerity returns the hash of fd, saved in verityDigest. func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, verityDigest usermem.Addr) (uintptr, error) { t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } var metadata linux.DigestMetadata // If allowRuntimeEnable is true, an empty fd.d.hash indicates that @@ -667,7 +782,7 @@ func (fd *fileDescription) measureVerity(ctx context.Context, uio usermem.IO, ve if fd.d.fs.allowRuntimeEnable { return 0, syserror.ENODATA } - return 0, alertIntegrityViolation(syserror.ENODATA, "Ioctl measureVerity: no hash found") + return 0, alertIntegrityViolation("Ioctl measureVerity: no hash found") } // The first part of VerityDigest is the metadata. @@ -702,6 +817,9 @@ func (fd *fileDescription) verityFlags(ctx context.Context, uio usermem.IO, flag } t := kernel.TaskFromContext(ctx) + if t == nil { + return 0, syserror.EINVAL + } _, err := primitive.CopyInt32Out(t, flags, f) return 0, err } @@ -722,6 +840,16 @@ func (fd *fileDescription) Ioctl(ctx context.Context, uio usermem.IO, args arch. } } +// Read implements vfs.FileDescriptionImpl.Read. +func (fd *fileDescription) Read(ctx context.Context, dst usermem.IOSequence, opts vfs.ReadOptions) (int64, error) { + // Implement Read with PRead by setting offset. + fd.mu.Lock() + n, err := fd.PRead(ctx, dst, fd.off, opts) + fd.off += n + fd.mu.Unlock() + return n, err +} + // PRead implements vfs.FileDescriptionImpl.PRead. func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, offset int64, opts vfs.ReadOptions) (int64, error) { // No need to verify if the file is not enabled yet in @@ -742,7 +870,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // contains the expected xattrs. If the xattr does not exist, it // indicates unexpected modifications to the file system. if err == syserror.ENODATA { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to get xattr %s: %v", merkleSizeXattr, err)) } if err != nil { return 0, err @@ -752,7 +880,7 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of // unexpected modifications to the file system. size, err := strconv.Atoi(dataSize) if err != nil { - return 0, alertIntegrityViolation(err, fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Failed to convert xattr %s to int: %v", merkleSizeXattr, err)) } dataReader := vfs.FileReadWriteSeeker{ @@ -766,25 +894,37 @@ func (fd *fileDescription) PRead(ctx context.Context, dst usermem.IOSequence, of } n, err := merkletree.Verify(&merkletree.VerifyParams{ - Out: dst.Writer(ctx), - File: &dataReader, - Tree: &merkleReader, - Size: int64(size), - Name: fd.d.name, - Mode: fd.d.mode, - UID: fd.d.uid, - GID: fd.d.gid, + Out: dst.Writer(ctx), + File: &dataReader, + Tree: &merkleReader, + Size: int64(size), + Name: fd.d.name, + Mode: fd.d.mode, + UID: fd.d.uid, + GID: fd.d.gid, + //TODO(b/156980949): Support passing other hash algorithms. + HashAlgorithms: fd.d.fs.alg.toLinuxHashAlg(), ReadOffset: offset, ReadSize: dst.NumBytes(), Expected: fd.d.hash, DataAndTreeInSameFile: false, }) if err != nil { - return 0, alertIntegrityViolation(syserror.EIO, fmt.Sprintf("Verification failed: %v", err)) + return 0, alertIntegrityViolation(fmt.Sprintf("Verification failed: %v", err)) } return n, err } +// PWrite implements vfs.FileDescriptionImpl.PWrite. +func (fd *fileDescription) PWrite(ctx context.Context, src usermem.IOSequence, offset int64, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + +// Write implements vfs.FileDescriptionImpl.Write. +func (fd *fileDescription) Write(ctx context.Context, src usermem.IOSequence, opts vfs.WriteOptions) (int64, error) { + return 0, syserror.EROFS +} + // LockPOSIX implements vfs.FileDescriptionImpl.LockPOSIX. func (fd *fileDescription) LockPOSIX(ctx context.Context, uid fslock.UniqueID, t fslock.LockType, start, length uint64, whence int16, block fslock.Blocker) error { return fd.lowerFD.LockPOSIX(ctx, uid, t, start, length, whence, block) diff --git a/pkg/sentry/fsimpl/verity/verity_test.go b/pkg/sentry/fsimpl/verity/verity_test.go index e301d35f5..b2da9dd96 100644 --- a/pkg/sentry/fsimpl/verity/verity_test.go +++ b/pkg/sentry/fsimpl/verity/verity_test.go @@ -25,10 +25,12 @@ import ( "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/fspath" "gvisor.dev/gvisor/pkg/sentry/arch" + "gvisor.dev/gvisor/pkg/sentry/fsimpl/testutil" "gvisor.dev/gvisor/pkg/sentry/fsimpl/tmpfs" + "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" - "gvisor.dev/gvisor/pkg/sentry/kernel/contexttest" "gvisor.dev/gvisor/pkg/sentry/vfs" + "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -41,11 +43,18 @@ const maxDataSize = 100000 // newVerityRoot creates a new verity mount, and returns the root. The // underlying file system is tmpfs. If the error is not nil, then cleanup // should be called when the root is no longer needed. -func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, vfs.VirtualDentry, error) { +func newVerityRoot(t *testing.T, hashAlg HashAlgorithm) (*vfs.VirtualFilesystem, vfs.VirtualDentry, *kernel.Task, error) { + k, err := testutil.Boot() + if err != nil { + t.Fatalf("testutil.Boot: %v", err) + } + + ctx := k.SupervisorContext() + rand.Seed(time.Now().UnixNano()) vfsObj := &vfs.VirtualFilesystem{} if err := vfsObj.Init(ctx); err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("VFS init: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("VFS init: %v", err) } vfsObj.MustRegisterFilesystemType("verity", FilesystemType{}, &vfs.RegisterFilesystemTypeOptions{ @@ -61,22 +70,33 @@ func newVerityRoot(ctx context.Context, t *testing.T) (*vfs.VirtualFilesystem, v InternalData: InternalFilesystemOptions{ RootMerkleFileName: rootMerkleFilename, LowerName: "tmpfs", + Alg: hashAlg, AllowRuntimeEnable: true, NoCrashOnVerificationFailure: true, }, }, }) if err != nil { - return nil, vfs.VirtualDentry{}, fmt.Errorf("NewMountNamespace: %v", err) + return nil, vfs.VirtualDentry{}, nil, fmt.Errorf("NewMountNamespace: %v", err) } root := mntns.Root() root.IncRef() + + // Use lowerRoot in the task as we modify the lower file system + // directly in many tests. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + tc := k.NewThreadGroup(nil, k.RootPIDNamespace(), kernel.NewSignalHandlers(), linux.SIGCHLD, k.GlobalInit().Limits()) + task, err := testutil.CreateTask(ctx, "name", tc, mntns, lowerRoot, lowerRoot) + if err != nil { + t.Fatalf("testutil.CreateTask: %v", err) + } + t.Helper() t.Cleanup(func() { root.DecRef(ctx) mntns.DecRef(ctx) }) - return vfsObj, root, nil + return vfsObj, root, task, nil } // newFileFD creates a new file in the verity mount, and returns the FD. The FD @@ -142,207 +162,296 @@ func corruptRandomBit(ctx context.Context, fd *vfs.FileDescription, size int) er return nil } +var hashAlgs = []HashAlgorithm{SHA256, SHA512} + // TestOpen ensures that when a file is created, the corresponding Merkle tree // file and the root Merkle tree file exist. func TestOpen(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Ensure that the corresponding Merkle tree file is created. - lowerRoot := root.Dentry().Impl().(*dentry).lowerVD - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + if _, _, err := newFileFD(ctx, vfsObj, root, filename, 0644); err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Ensure that the corresponding Merkle tree file is created. + lowerRoot := root.Dentry().Impl().(*dentry).lowerVD + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt Merkle tree file %s: %v", merklePrefix+filename, err) + } + + // Ensure the root merkle tree file is created. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerRoot, + Start: lowerRoot, + Path: fspath.Parse(merklePrefix + rootMerkleFilename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }); err != nil { + t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) + } } +} - // Ensure the root merkle tree file is created. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerRoot, - Start: lowerRoot, - Path: fspath.Parse(merklePrefix + rootMerkleFilename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }); err != nil { - t.Errorf("OpenAt root Merkle tree file %s: %v", merklePrefix+rootMerkleFilename, err) +// TestPReadUnmodifiedFileSucceeds ensures that pread from an untouched verity +// file succeeds after enabling verity for it. +func TestPReadUnmodifiedFileSucceeds(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.PRead: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } } -// TestUnmodifiedFileSucceeds ensures that read from an untouched verity file -// succeeds after enabling verity for it. +// TestReadUnmodifiedFileSucceeds ensures that read from an untouched verity +// file succeeds after enabling verity for it. func TestReadUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirm a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - buf := make([]byte, size) - n, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}) - if err != nil && err != io.EOF { - t.Fatalf("fd.PRead: %v", err) - } - - if n != int64(size) { - t.Errorf("fd.PRead got read length %d, want %d", n, size) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirm a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + buf := make([]byte, size) + n, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}) + if err != nil && err != io.EOF { + t.Fatalf("fd.Read: %v", err) + } + + if n != int64(size) { + t.Errorf("fd.PRead got read length %d, want %d", n, size) + } } } // TestReopenUnmodifiedFileSucceeds ensures that reopen an untouched verity file // succeeds after enabling verity for it. func TestReopenUnmodifiedFileSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms a normal read succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Ensure reopening the verity enabled file succeeds. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err != nil { - t.Errorf("reopen enabled file failed: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms a normal read succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Ensure reopening the verity enabled file succeeds. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != nil { + t.Errorf("reopen enabled file failed: %v", err) + } } } -// TestModifiedFileFails ensures that read from a modified verity file fails. -func TestModifiedFileFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerFD that's read/writable. - lowerVD := fd.Impl().(*fileDescription).d.lowerVD - - lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerVD, - Start: lowerVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if err := corruptRandomBit(ctx, lowerFD, size); err != nil { - t.Fatalf("corruptRandomBit: %v", err) +// TestPReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestPReadModifiedFileFails(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.PRead succeeded, expected failure") + } } +} - // Confirm that read from the modified file fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - t.Fatalf("fd.PRead succeeded with modified file") +// TestReadModifiedFileFails ensures that read from a modified verity file +// fails. +func TestReadModifiedFileFails(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerFD that's read/writable. + lowerVD := fd.Impl().(*fileDescription).d.lowerVD + + lowerFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerVD, + Start: lowerVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if err := corruptRandomBit(ctx, lowerFD, size); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from the modified file fails. + buf := make([]byte, size) + if _, err := fd.Read(ctx, usermem.BytesIOSequence(buf), vfs.ReadOptions{}); err == nil { + t.Fatalf("fd.Read succeeded, expected failure") + } } } // TestModifiedMerkleFails ensures that read from a verity file fails if the // corresponding Merkle tree file is modified. func TestModifiedMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD - - lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: lowerMerkleVD, - Start: lowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the Merkle tree file. - stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - merkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - // Confirm that read from a file with modified Merkle tree fails. - buf := make([]byte, size) - if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { - fmt.Println(buf) - t.Fatalf("fd.PRead succeeded with modified Merkle file") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, size, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + lowerMerkleVD := fd.Impl().(*fileDescription).d.lowerMerkleVD + + lowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: lowerMerkleVD, + Start: lowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the Merkle tree file. + stat, err := lowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + merkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, lowerMerkleFD, merkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + // Confirm that read from a file with modified Merkle tree fails. + buf := make([]byte, size) + if _, err := fd.PRead(ctx, usermem.BytesIOSequence(buf), 0 /* offset */, vfs.ReadOptions{}); err == nil { + fmt.Println(buf) + t.Fatalf("fd.PRead succeeded with modified Merkle file") + } } } @@ -350,142 +459,267 @@ func TestModifiedMerkleFails(t *testing.T) { // verity enabled directory fails if the hashes related to the target file in // the parent Merkle tree file is modified. func TestModifiedParentMerkleFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Enable verity on the parent directory. - parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("Ioctl: %v", err) - } - - // Open a new lowerMerkleFD that's read/writable. - parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD - - parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: parentLowerMerkleVD, - Start: parentLowerMerkleVD, - }, &vfs.OpenOptions{ - Flags: linux.O_RDWR, - }) - if err != nil { - t.Fatalf("OpenAt: %v", err) - } - - // Flip a random bit in the parent Merkle tree file. - // This parent directory contains only one child, so any random - // modification in the parent Merkle tree should cause verification - // failure when opening the child file. - stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) - if err != nil { - t.Fatalf("stat: %v", err) - } - parentMerkleSize := int(stat.Size) - if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { - t.Fatalf("corruptRandomBit: %v", err) - } - - parentLowerMerkleFD.DecRef(ctx) - - // Ensure reopening the verity enabled file fails. - if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ - Root: root, - Start: root, - Path: fspath.Parse(filename), - }, &vfs.OpenOptions{ - Flags: linux.O_RDONLY, - Mode: linux.ModeRegular, - }); err == nil { - t.Errorf("OpenAt file with modified parent Merkle succeeded") + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Enable verity on the parent directory. + parentFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + if _, err := parentFD.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + // Open a new lowerMerkleFD that's read/writable. + parentLowerMerkleVD := fd.Impl().(*fileDescription).d.parent.lowerMerkleVD + + parentLowerMerkleFD, err := vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: parentLowerMerkleVD, + Start: parentLowerMerkleVD, + }, &vfs.OpenOptions{ + Flags: linux.O_RDWR, + }) + if err != nil { + t.Fatalf("OpenAt: %v", err) + } + + // Flip a random bit in the parent Merkle tree file. + // This parent directory contains only one child, so any random + // modification in the parent Merkle tree should cause verification + // failure when opening the child file. + stat, err := parentLowerMerkleFD.Stat(ctx, vfs.StatOptions{}) + if err != nil { + t.Fatalf("stat: %v", err) + } + parentMerkleSize := int(stat.Size) + if err := corruptRandomBit(ctx, parentLowerMerkleFD, parentMerkleSize); err != nil { + t.Fatalf("corruptRandomBit: %v", err) + } + + parentLowerMerkleFD.DecRef(ctx) + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err == nil { + t.Errorf("OpenAt file with modified parent Merkle succeeded") + } } } // TestUnmodifiedStatSucceeds ensures that stat of an untouched verity file // succeeds after enabling verity for it. func TestUnmodifiedStatSucceeds(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } - - // Enable verity on the file and confirms stat succeeds. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) - } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { - t.Errorf("fd.Stat: %v", err) + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file and confirms stat succeeds. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err != nil { + t.Errorf("fd.Stat: %v", err) + } } } // TestModifiedStatFails checks that getting stat for a file with modified stat // should fail. func TestModifiedStatFails(t *testing.T) { - ctx := contexttest.Context(t) - vfsObj, root, err := newVerityRoot(ctx, t) - if err != nil { - t.Fatalf("newVerityRoot: %v", err) - } - - filename := "verity-test-file" - fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) - if err != nil { - t.Fatalf("newFileFD: %v", err) - } + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("fd.Ioctl: %v", err) + } + + lowerFD := fd.Impl().(*fileDescription).lowerFD + // Change the stat of the underlying file, and check that stat fails. + if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ + Stat: linux.Statx{ + Mask: uint32(linux.STATX_MODE), + Mode: 0777, + }, + }); err != nil { + t.Fatalf("lowerFD.SetStat: %v", err) + } - // Enable verity on the file. - var args arch.SyscallArguments - args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} - if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { - t.Fatalf("fd.Ioctl: %v", err) + if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { + t.Errorf("fd.Stat succeeded when it should fail") + } } +} - lowerFD := fd.Impl().(*fileDescription).lowerFD - // Change the stat of the underlying file, and check that stat fails. - if err := lowerFD.SetStat(ctx, vfs.SetStatOptions{ - Stat: linux.Statx{ - Mask: uint32(linux.STATX_MODE), - Mode: 0777, +// TestOpenDeletedOrRenamedFileFails ensures that opening a deleted/renamed +// verity enabled file or the corresponding Merkle tree file fails with the +// verify error. +func TestOpenDeletedFileFails(t *testing.T) { + testCases := []struct { + // Tests removing files is remove is true. Otherwise tests + // renaming files. + remove bool + // The original file is removed/renamed if changeFile is true. + changeFile bool + // The Merkle tree file is removed/renamed if changeMerkleFile + // is true. + changeMerkleFile bool + }{ + { + remove: true, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: true, + changeFile: false, + changeMerkleFile: true, + }, + { + remove: false, + changeFile: true, + changeMerkleFile: false, + }, + { + remove: false, + changeFile: true, + changeMerkleFile: false, }, - }); err != nil { - t.Fatalf("lowerFD.SetStat: %v", err) } - - if _, err := fd.Stat(ctx, vfs.StatOptions{}); err == nil { - t.Errorf("fd.Stat succeeded when it should fail") + for _, tc := range testCases { + t.Run(fmt.Sprintf("remove:%t", tc.remove), func(t *testing.T) { + for _, alg := range hashAlgs { + vfsObj, root, ctx, err := newVerityRoot(t, alg) + if err != nil { + t.Fatalf("newVerityRoot: %v", err) + } + + filename := "verity-test-file" + fd, _, err := newFileFD(ctx, vfsObj, root, filename, 0644) + if err != nil { + t.Fatalf("newFileFD: %v", err) + } + + // Enable verity on the file. + var args arch.SyscallArguments + args[1] = arch.SyscallArgument{Value: linux.FS_IOC_ENABLE_VERITY} + if _, err := fd.Ioctl(ctx, nil /* uio */, args); err != nil { + t.Fatalf("Ioctl: %v", err) + } + + rootLowerVD := root.Dentry().Impl().(*dentry).lowerVD + if tc.remove { + if tc.changeFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.UnlinkAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + } else { + newFilename := "renamed-test-file" + if tc.changeFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("RenameAt: %v", err) + } + } + if tc.changeMerkleFile { + if err := vfsObj.RenameAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + filename), + }, &vfs.PathOperation{ + Root: rootLowerVD, + Start: rootLowerVD, + Path: fspath.Parse(merklePrefix + newFilename), + }, &vfs.RenameOptions{}); err != nil { + t.Fatalf("UnlinkAt: %v", err) + } + } + } + + // Ensure reopening the verity enabled file fails. + if _, err = vfsObj.OpenAt(ctx, auth.CredentialsFromContext(ctx), &vfs.PathOperation{ + Root: root, + Start: root, + Path: fspath.Parse(filename), + }, &vfs.OpenOptions{ + Flags: linux.O_RDONLY, + Mode: linux.ModeRegular, + }); err != syserror.EIO { + t.Errorf("got OpenAt error: %v, expected EIO", err) + } + } + }) } } diff --git a/pkg/sentry/hostfd/BUILD b/pkg/sentry/hostfd/BUILD index 364a78306..db3b0d0a0 100644 --- a/pkg/sentry/hostfd/BUILD +++ b/pkg/sentry/hostfd/BUILD @@ -6,10 +6,12 @@ go_library( name = "hostfd", srcs = [ "hostfd.go", + "hostfd_linux.go", "hostfd_unsafe.go", ], visibility = ["//pkg/sentry:internal"], deps = [ + "//pkg/log", "//pkg/safemem", "//pkg/sync", "@org_golang_x_sys//unix:go_default_library", diff --git a/pkg/sentry/hostfd/hostfd_linux.go b/pkg/sentry/hostfd/hostfd_linux.go new file mode 100644 index 000000000..1cabc848f --- /dev/null +++ b/pkg/sentry/hostfd/hostfd_linux.go @@ -0,0 +1,18 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package hostfd + +// maxIov is the maximum permitted size of a struct iovec array. +const maxIov = 1024 // UIO_MAXIOV diff --git a/pkg/sentry/hostfd/hostfd_unsafe.go b/pkg/sentry/hostfd/hostfd_unsafe.go index cd4dc67fb..694371b1c 100644 --- a/pkg/sentry/hostfd/hostfd_unsafe.go +++ b/pkg/sentry/hostfd/hostfd_unsafe.go @@ -20,6 +20,7 @@ import ( "unsafe" "golang.org/x/sys/unix" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/safemem" ) @@ -44,6 +45,10 @@ func Preadv2(fd int32, dsts safemem.BlockSeq, offset int64, flags uint32) (uint6 } } else { iovs := safemem.IovecsFromBlockSeq(dsts) + if len(iovs) > maxIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PREADV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { @@ -76,6 +81,10 @@ func Pwritev2(fd int32, srcs safemem.BlockSeq, offset int64, flags uint32) (uint } } else { iovs := safemem.IovecsFromBlockSeq(srcs) + if len(iovs) > maxIov { + log.Debugf("hostfd.Preadv2: truncating from %d iovecs to %d", len(iovs), maxIov) + iovs = iovs[:maxIov] + } n, _, e = syscall.Syscall6(unix.SYS_PWRITEV2, uintptr(fd), uintptr((unsafe.Pointer)(&iovs[0])), uintptr(len(iovs)), uintptr(offset), 0 /* pos_h */, uintptr(flags)) } if e != 0 { diff --git a/pkg/sentry/inet/inet.go b/pkg/sentry/inet/inet.go index fbe6d6aa6..f31277d30 100644 --- a/pkg/sentry/inet/inet.go +++ b/pkg/sentry/inet/inet.go @@ -32,9 +32,13 @@ type Stack interface { InterfaceAddrs() map[int32][]InterfaceAddr // AddInterfaceAddr adds an address to the network interface identified by - // index. + // idx. AddInterfaceAddr(idx int32, addr InterfaceAddr) error + // RemoveInterfaceAddr removes an address from the network interface + // identified by idx. + RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error + // SupportsIPv6 returns true if the stack supports IPv6 connectivity. SupportsIPv6() bool diff --git a/pkg/sentry/inet/test_stack.go b/pkg/sentry/inet/test_stack.go index 1779cc6f3..9ebeba8a3 100644 --- a/pkg/sentry/inet/test_stack.go +++ b/pkg/sentry/inet/test_stack.go @@ -15,6 +15,9 @@ package inet import ( + "bytes" + "fmt" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/stack" ) @@ -58,6 +61,24 @@ func (s *TestStack) AddInterfaceAddr(idx int32, addr InterfaceAddr) error { return nil } +// RemoveInterfaceAddr implements Stack.RemoveInterfaceAddr. +func (s *TestStack) RemoveInterfaceAddr(idx int32, addr InterfaceAddr) error { + interfaceAddrs, ok := s.InterfaceAddrsMap[idx] + if !ok { + return fmt.Errorf("unknown idx: %d", idx) + } + + var filteredAddrs []InterfaceAddr + for _, interfaceAddr := range interfaceAddrs { + if !bytes.Equal(interfaceAddr.Addr, addr.Addr) { + filteredAddrs = append(filteredAddrs, addr) + } + } + s.InterfaceAddrsMap[idx] = filteredAddrs + + return nil +} + // SupportsIPv6 implements Stack.SupportsIPv6. func (s *TestStack) SupportsIPv6() bool { return s.SupportsIPv6Flag diff --git a/pkg/sentry/kernel/BUILD b/pkg/sentry/kernel/BUILD index c0de72eef..90dd4a047 100644 --- a/pkg/sentry/kernel/BUILD +++ b/pkg/sentry/kernel/BUILD @@ -79,7 +79,7 @@ go_template_instance( out = "fd_table_refs.go", package = "kernel", prefix = "FDTable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FDTable", }, @@ -90,7 +90,7 @@ go_template_instance( out = "fs_context_refs.go", package = "kernel", prefix = "FSContext", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FSContext", }, @@ -101,7 +101,7 @@ go_template_instance( out = "ipc_namespace_refs.go", package = "kernel", prefix = "IPCNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "IPCNamespace", }, @@ -112,7 +112,7 @@ go_template_instance( out = "process_group_refs.go", package = "kernel", prefix = "ProcessGroup", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "ProcessGroup", }, @@ -123,7 +123,7 @@ go_template_instance( out = "session_refs.go", package = "kernel", prefix = "Session", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Session", }, @@ -229,7 +229,7 @@ go_library( "//pkg/marshal/primitive", "//pkg/metric", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/secio", "//pkg/sentry/arch", diff --git a/pkg/sentry/kernel/abstract_socket_namespace.go b/pkg/sentry/kernel/abstract_socket_namespace.go index 1b9721534..0ddbe5ff6 100644 --- a/pkg/sentry/kernel/abstract_socket_namespace.go +++ b/pkg/sentry/kernel/abstract_socket_namespace.go @@ -19,7 +19,7 @@ import ( "syscall" "gvisor.dev/gvisor/pkg/context" - "gvisor.dev/gvisor/pkg/refs_vfs2" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/socket/unix/transport" "gvisor.dev/gvisor/pkg/sync" ) @@ -27,7 +27,7 @@ import ( // +stateify savable type abstractEndpoint struct { ep transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter name string ns *AbstractSocketNamespace } @@ -57,7 +57,7 @@ func NewAbstractSocketNamespace() *AbstractSocketNamespace { // its backing socket. type boundEndpoint struct { transport.BoundEndpoint - socket refs_vfs2.RefCounter + socket refsvfs2.RefCounter } // Release implements transport.BoundEndpoint.Release. @@ -89,7 +89,7 @@ func (a *AbstractSocketNamespace) BoundEndpoint(name string) transport.BoundEndp // // When the last reference managed by socket is dropped, ep may be removed from the // namespace. -func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refs_vfs2.RefCounter) error { +func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep transport.BoundEndpoint, socket refsvfs2.RefCounter) error { a.mu.Lock() defer a.mu.Unlock() @@ -109,7 +109,7 @@ func (a *AbstractSocketNamespace) Bind(ctx context.Context, name string, ep tran // Remove removes the specified socket at name from the abstract socket // namespace, if it has not yet been replaced. -func (a *AbstractSocketNamespace) Remove(name string, socket refs_vfs2.RefCounter) { +func (a *AbstractSocketNamespace) Remove(name string, socket refsvfs2.RefCounter) { a.mu.Lock() defer a.mu.Unlock() diff --git a/pkg/sentry/kernel/fd_table.go b/pkg/sentry/kernel/fd_table.go index 0ec7344cd..7aba31587 100644 --- a/pkg/sentry/kernel/fd_table.go +++ b/pkg/sentry/kernel/fd_table.go @@ -110,7 +110,7 @@ func (f *FDTable) saveDescriptorTable() map[int32]descriptor { func (f *FDTable) loadDescriptorTable(m map[int32]descriptor) { ctx := context.Background() - f.init() // Initialize table. + f.initNoLeakCheck() // Initialize table. f.used = 0 for fd, d := range m { if file, fileVFS2 := f.setAll(ctx, fd, d.file, d.fileVFS2, d.flags); file != nil || fileVFS2 != nil { @@ -240,6 +240,10 @@ func (f *FDTable) String() string { case fileVFS2 != nil: vfsObj := fileVFS2.Mount().Filesystem().VirtualFilesystem() + vd := fileVFS2.VirtualDentry() + if vd.Dentry() == nil { + panic(fmt.Sprintf("fd %d (type %T) has nil dentry: %#v", fd, fileVFS2.Impl(), fileVFS2)) + } name, err := vfsObj.PathnameWithDeleted(ctx, vfs.VirtualDentry{}, fileVFS2.VirtualDentry()) if err != nil { fmt.Fprintf(&buf, "<err: %v>\n", err) diff --git a/pkg/sentry/kernel/fd_table_unsafe.go b/pkg/sentry/kernel/fd_table_unsafe.go index da79e6627..3476551f3 100644 --- a/pkg/sentry/kernel/fd_table_unsafe.go +++ b/pkg/sentry/kernel/fd_table_unsafe.go @@ -31,14 +31,21 @@ type descriptorTable struct { slice unsafe.Pointer `state:".(map[int32]*descriptor)"` } -// init initializes the table. +// initNoLeakCheck initializes the table without enabling leak checking. // -// TODO(gvisor.dev/1486): Enable leak check for FDTable. -func (f *FDTable) init() { +// This is used when loading an FDTable after S/R, during which the ref count +// object itself will enable leak checking if necessary. +func (f *FDTable) initNoLeakCheck() { var slice []unsafe.Pointer // Empty slice. atomic.StorePointer(&f.slice, unsafe.Pointer(&slice)) } +// init initializes the table with leak checking. +func (f *FDTable) init() { + f.initNoLeakCheck() + f.EnableLeakCheck() +} + // get gets a file entry. // // The boolean indicates whether this was in range. diff --git a/pkg/sentry/kernel/fs_context.go b/pkg/sentry/kernel/fs_context.go index d46d1e1c1..41fb2a784 100644 --- a/pkg/sentry/kernel/fs_context.go +++ b/pkg/sentry/kernel/fs_context.go @@ -130,13 +130,15 @@ func (f *FSContext) Fork() *FSContext { f.root.IncRef() } - return &FSContext{ + ctx := &FSContext{ cwd: f.cwd, root: f.root, cwdVFS2: f.cwdVFS2, rootVFS2: f.rootVFS2, umask: f.umask, } + ctx.EnableLeakCheck() + return ctx } // WorkingDirectory returns the current working directory. @@ -147,19 +149,23 @@ func (f *FSContext) WorkingDirectory() *fs.Dirent { f.mu.Lock() defer f.mu.Unlock() - f.cwd.IncRef() + if f.cwd != nil { + f.cwd.IncRef() + } return f.cwd } // WorkingDirectoryVFS2 returns the current working directory. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) WorkingDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.cwdVFS2.IncRef() + if f.cwdVFS2.Ok() { + f.cwdVFS2.IncRef() + } return f.cwdVFS2 } @@ -218,13 +224,15 @@ func (f *FSContext) RootDirectory() *fs.Dirent { // RootDirectoryVFS2 returns the current filesystem root. // -// This will return nil if called after f is destroyed, otherwise it will return -// a Dirent with a reference taken. +// This will return an empty vfs.VirtualDentry if called after f is +// destroyed, otherwise it will return a Dirent with a reference taken. func (f *FSContext) RootDirectoryVFS2() vfs.VirtualDentry { f.mu.Lock() defer f.mu.Unlock() - f.rootVFS2.IncRef() + if f.rootVFS2.Ok() { + f.rootVFS2.IncRef() + } return f.rootVFS2 } diff --git a/pkg/sentry/kernel/ipc_namespace.go b/pkg/sentry/kernel/ipc_namespace.go index 3f34ee0db..b87e40dd1 100644 --- a/pkg/sentry/kernel/ipc_namespace.go +++ b/pkg/sentry/kernel/ipc_namespace.go @@ -55,7 +55,7 @@ func (i *IPCNamespace) ShmRegistry() *shm.Registry { return i.shms } -// DecRef implements refs_vfs2.RefCounter.DecRef. +// DecRef implements refsvfs2.RefCounter.DecRef. func (i *IPCNamespace) DecRef(ctx context.Context) { i.IPCNamespaceRefs.DecRef(func() { i.shms.Release(ctx) diff --git a/pkg/sentry/kernel/kernel.go b/pkg/sentry/kernel/kernel.go index 0eb2bf7bd..9b2be44d4 100644 --- a/pkg/sentry/kernel/kernel.go +++ b/pkg/sentry/kernel/kernel.go @@ -430,9 +430,8 @@ func (k *Kernel) Init(args InitKernelArgs) error { // SaveTo saves the state of k to w. // // Preconditions: The kernel must be paused throughout the call to SaveTo. -func (k *Kernel) SaveTo(w wire.Writer) error { +func (k *Kernel) SaveTo(ctx context.Context, w wire.Writer) error { saveStart := time.Now() - ctx := k.SupervisorContext() // Do not allow other Kernel methods to affect it while it's being saved. k.extMu.Lock() @@ -446,38 +445,55 @@ func (k *Kernel) SaveTo(w wire.Writer) error { k.mf.StartEvictions() k.mf.WaitForEvictions() - // Flush write operations on open files so data reaches backing storage. - // This must come after MemoryFile eviction since eviction may cause file - // writes. - if err := k.tasks.flushWritesToFiles(ctx); err != nil { - return err - } + if VFS2Enabled { + // Discard unsavable mappings, such as those for host file descriptors. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } - // Remove all epoll waiter objects from underlying wait queues. - // NOTE: for programs to resume execution in future snapshot scenarios, - // we will need to re-establish these waiter objects after saving. - k.tasks.unregisterEpollWaiters(ctx) + // Prepare filesystems for saving. This must be done after + // invalidateUnsavableMappings(), since dropping memory mappings may + // affect filesystem state (e.g. page cache reference counts). + if err := k.vfs.PrepareSave(ctx); err != nil { + return err + } + } else { + // Flush cached file writes to backing storage. This must come after + // MemoryFile eviction since eviction may cause file writes. + if err := k.flushWritesToFiles(ctx); err != nil { + return err + } - // Clear the dirent cache before saving because Dirents must be Loaded in a - // particular order (parents before children), and Loading dirents from a cache - // breaks that order. - if err := k.flushMountSourceRefs(ctx); err != nil { - return err - } + // Remove all epoll waiter objects from underlying wait queues. + // NOTE: for programs to resume execution in future snapshot scenarios, + // we will need to re-establish these waiter objects after saving. + k.tasks.unregisterEpollWaiters(ctx) - // Ensure that all inode and mount release operations have completed. - fs.AsyncBarrier() + // Clear the dirent cache before saving because Dirents must be Loaded in a + // particular order (parents before children), and Loading dirents from a cache + // breaks that order. + if err := k.flushMountSourceRefs(ctx); err != nil { + return err + } - // Once all fs work has completed (flushed references have all been released), - // reset mount mappings. This allows individual mounts to save how inodes map - // to filesystem resources. Without this, fs.Inodes cannot be restored. - fs.SaveInodeMappings() + // Ensure that all inode and mount release operations have completed. + fs.AsyncBarrier() - // Discard unsavable mappings, such as those for host file descriptors. - // This must be done after waiting for "asynchronous fs work", which - // includes async I/O that may touch application memory. - if err := k.invalidateUnsavableMappings(ctx); err != nil { - return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + // Once all fs work has completed (flushed references have all been released), + // reset mount mappings. This allows individual mounts to save how inodes map + // to filesystem resources. Without this, fs.Inodes cannot be restored. + fs.SaveInodeMappings() + + // Discard unsavable mappings, such as those for host file descriptors. + // This must be done after waiting for "asynchronous fs work", which + // includes async I/O that may touch application memory. + // + // TODO(gvisor.dev/issue/1624): This rationale is believed to be + // obsolete since AIO callbacks are now waited-for by Kernel.Pause(), + // but this order is conservatively retained for VFS1. + if err := k.invalidateUnsavableMappings(ctx); err != nil { + return fmt.Errorf("failed to invalidate unsavable mappings: %v", err) + } } // Save the CPUID FeatureSet before the rest of the kernel so we can @@ -486,14 +502,14 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // // N.B. This will also be saved along with the full kernel save below. cpuidStart := time.Now() - if _, err := state.Save(k.SupervisorContext(), w, k.FeatureSet()); err != nil { + if _, err := state.Save(ctx, w, k.FeatureSet()); err != nil { return err } log.Infof("CPUID save took [%s].", time.Since(cpuidStart)) // Save the kernel state. kernelStart := time.Now() - stats, err := state.Save(k.SupervisorContext(), w, k) + stats, err := state.Save(ctx, w, k) if err != nil { return err } @@ -502,7 +518,7 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // Save the memory file's state. memoryStart := time.Now() - if err := k.mf.SaveTo(k.SupervisorContext(), w); err != nil { + if err := k.mf.SaveTo(ctx, w); err != nil { return err } log.Infof("Memory save took [%s].", time.Since(memoryStart)) @@ -514,11 +530,9 @@ func (k *Kernel) SaveTo(w wire.Writer) error { // flushMountSourceRefs flushes the MountSources for all mounted filesystems // and open FDs. +// +// Preconditions: !VFS2Enabled. func (k *Kernel) flushMountSourceRefs(ctx context.Context) error { - if VFS2Enabled { - return nil // Not relevant. - } - // Flush all mount sources for currently mounted filesystems in each task. flushed := make(map[*fs.MountNamespace]struct{}) k.tasks.mu.RLock() @@ -561,13 +575,9 @@ func (ts *TaskSet) forEachFDPaused(ctx context.Context, f func(*fs.File, *vfs.Fi return err } -func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return nil - } - - return ts.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { +// Preconditions: !VFS2Enabled. +func (k *Kernel) flushWritesToFiles(ctx context.Context) error { + return k.tasks.forEachFDPaused(ctx, func(file *fs.File, _ *vfs.FileDescription) error { if flags := file.Flags(); !flags.Write { return nil } @@ -589,37 +599,8 @@ func (ts *TaskSet) flushWritesToFiles(ctx context.Context) error { }) } -// Preconditions: The kernel must be paused. -func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { - invalidated := make(map[*mm.MemoryManager]struct{}) - k.tasks.mu.RLock() - defer k.tasks.mu.RUnlock() - for t := range k.tasks.Root.tids { - // We can skip locking Task.mu here since the kernel is paused. - if mm := t.tc.MemoryManager; mm != nil { - if _, ok := invalidated[mm]; !ok { - if err := mm.InvalidateUnsavable(ctx); err != nil { - return err - } - invalidated[mm] = struct{}{} - } - } - // I really wish we just had a sync.Map of all MMs... - if r, ok := t.runState.(*runSyscallAfterExecStop); ok { - if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { - return err - } - } - } - return nil -} - +// Preconditions: !VFS2Enabled. func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { - // TODO(gvisor.dev/issue/1663): Add save support for VFS2. - if VFS2Enabled { - return - } - ts.mu.RLock() defer ts.mu.RUnlock() @@ -644,8 +625,33 @@ func (ts *TaskSet) unregisterEpollWaiters(ctx context.Context) { } } +// Preconditions: The kernel must be paused. +func (k *Kernel) invalidateUnsavableMappings(ctx context.Context) error { + invalidated := make(map[*mm.MemoryManager]struct{}) + k.tasks.mu.RLock() + defer k.tasks.mu.RUnlock() + for t := range k.tasks.Root.tids { + // We can skip locking Task.mu here since the kernel is paused. + if mm := t.tc.MemoryManager; mm != nil { + if _, ok := invalidated[mm]; !ok { + if err := mm.InvalidateUnsavable(ctx); err != nil { + return err + } + invalidated[mm] = struct{}{} + } + } + // I really wish we just had a sync.Map of all MMs... + if r, ok := t.runState.(*runSyscallAfterExecStop); ok { + if err := r.tc.MemoryManager.InvalidateUnsavable(ctx); err != nil { + return err + } + } + } + return nil +} + // LoadFrom returns a new Kernel loaded from args. -func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clocks) error { +func (k *Kernel) LoadFrom(ctx context.Context, r wire.Reader, net inet.Stack, clocks sentrytime.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { loadStart := time.Now() initAppCores := k.applicationCores @@ -656,7 +662,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // don't need to explicitly install it in the Kernel. cpuidStart := time.Now() var features cpuid.FeatureSet - if _, err := state.Load(k.SupervisorContext(), r, &features); err != nil { + if _, err := state.Load(ctx, r, &features); err != nil { return err } log.Infof("CPUID load took [%s].", time.Since(cpuidStart)) @@ -671,7 +677,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the kernel state. kernelStart := time.Now() - stats, err := state.Load(k.SupervisorContext(), r, k) + stats, err := state.Load(ctx, r, k) if err != nil { return err } @@ -684,7 +690,7 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock // Load the memory file's state. memoryStart := time.Now() - if err := k.mf.LoadFrom(k.SupervisorContext(), r); err != nil { + if err := k.mf.LoadFrom(ctx, r); err != nil { return err } log.Infof("Memory load took [%s].", time.Since(memoryStart)) @@ -696,11 +702,17 @@ func (k *Kernel) LoadFrom(r wire.Reader, net inet.Stack, clocks sentrytime.Clock net.Resume() } - // Ensure that all pending asynchronous work is complete: - // - namedpipe opening - // - inode file opening - if err := fs.AsyncErrorBarrier(); err != nil { - return err + if VFS2Enabled { + if err := k.vfs.CompleteRestore(ctx, vfsOpts); err != nil { + return err + } + } else { + // Ensure that all pending asynchronous work is complete: + // - namedpipe opening + // - inode file opening + if err := fs.AsyncErrorBarrier(); err != nil { + return err + } } tcpip.AsyncLoading.Wait() diff --git a/pkg/sentry/kernel/pipe/node_test.go b/pkg/sentry/kernel/pipe/node_test.go index ce0db5583..d6fb0fdb8 100644 --- a/pkg/sentry/kernel/pipe/node_test.go +++ b/pkg/sentry/kernel/pipe/node_test.go @@ -22,7 +22,6 @@ import ( "gvisor.dev/gvisor/pkg/sentry/contexttest" "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/syserror" - "gvisor.dev/gvisor/pkg/usermem" ) type sleeper struct { @@ -66,7 +65,8 @@ func testOpenOrDie(ctx context.Context, t *testing.T, n fs.InodeOperations, flag d := fs.NewDirent(ctx, inode, "pipe") file, err := n.GetFile(ctx, d, flags) if err != nil { - t.Fatalf("open with flags %+v failed: %v", flags, err) + t.Errorf("open with flags %+v failed: %v", flags, err) + return nil, err } if doneChan != nil { doneChan <- struct{}{} @@ -85,11 +85,11 @@ func testOpen(ctx context.Context, t *testing.T, n fs.InodeOperations, flags fs. } func newNamedPipe(t *testing.T) *Pipe { - return NewPipe(true, DefaultPipeSize, usermem.PageSize) + return NewPipe(true, DefaultPipeSize) } func newAnonPipe(t *testing.T) *Pipe { - return NewPipe(false, DefaultPipeSize, usermem.PageSize) + return NewPipe(false, DefaultPipeSize) } // assertRecvBlocks ensures that a recv attempt on c blocks for at least diff --git a/pkg/sentry/kernel/pipe/pipe.go b/pkg/sentry/kernel/pipe/pipe.go index 67beb0ad6..b989e14c7 100644 --- a/pkg/sentry/kernel/pipe/pipe.go +++ b/pkg/sentry/kernel/pipe/pipe.go @@ -26,18 +26,27 @@ import ( "gvisor.dev/gvisor/pkg/sentry/fs" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/syserror" + "gvisor.dev/gvisor/pkg/usermem" "gvisor.dev/gvisor/pkg/waiter" ) const ( // MinimumPipeSize is a hard limit of the minimum size of a pipe. - MinimumPipeSize = 64 << 10 + // It corresponds to fs/pipe.c:pipe_min_size. + MinimumPipeSize = usermem.PageSize + + // MaximumPipeSize is a hard limit on the maximum size of a pipe. + // It corresponds to fs/pipe.c:pipe_max_size. + MaximumPipeSize = 1048576 // DefaultPipeSize is the system-wide default size of a pipe in bytes. - DefaultPipeSize = MinimumPipeSize + // It corresponds to pipe_fs_i.h:PIPE_DEF_BUFFERS. + DefaultPipeSize = 16 * usermem.PageSize - // MaximumPipeSize is a hard limit on the maximum size of a pipe. - MaximumPipeSize = 8 << 20 + // atomicIOBytes is the maximum number of bytes that the pipe will + // guarantee atomic reads or writes atomically. + // It corresponds to limits.h:PIPE_BUF. + atomicIOBytes = 4096 ) // Pipe is an encapsulation of a platform-independent pipe. @@ -53,12 +62,6 @@ type Pipe struct { // This value is immutable. isNamed bool - // atomicIOBytes is the maximum number of bytes that the pipe will - // guarantee atomic reads or writes atomically. - // - // This value is immutable. - atomicIOBytes int64 - // The number of active readers for this pipe. // // Access atomically. @@ -94,47 +97,34 @@ type Pipe struct { // NewPipe initializes and returns a pipe. // -// N.B. The size and atomicIOBytes will be bounded. -func NewPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *Pipe { +// N.B. The size will be bounded. +func NewPipe(isNamed bool, sizeBytes int64) *Pipe { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } var p Pipe - initPipe(&p, isNamed, sizeBytes, atomicIOBytes) + initPipe(&p, isNamed, sizeBytes) return &p } -func initPipe(pipe *Pipe, isNamed bool, sizeBytes, atomicIOBytes int64) { +func initPipe(pipe *Pipe, isNamed bool, sizeBytes int64) { if sizeBytes < MinimumPipeSize { sizeBytes = MinimumPipeSize } if sizeBytes > MaximumPipeSize { sizeBytes = MaximumPipeSize } - if atomicIOBytes <= 0 { - atomicIOBytes = 1 - } - if atomicIOBytes > sizeBytes { - atomicIOBytes = sizeBytes - } pipe.isNamed = isNamed pipe.max = sizeBytes - pipe.atomicIOBytes = atomicIOBytes } // NewConnectedPipe initializes a pipe and returns a pair of objects // representing the read and write ends of the pipe. -func NewConnectedPipe(ctx context.Context, sizeBytes, atomicIOBytes int64) (*fs.File, *fs.File) { - p := NewPipe(false /* isNamed */, sizeBytes, atomicIOBytes) +func NewConnectedPipe(ctx context.Context, sizeBytes int64) (*fs.File, *fs.File) { + p := NewPipe(false /* isNamed */, sizeBytes) // Build an fs.Dirent for the pipe which will be shared by both // returned files. @@ -264,7 +254,7 @@ func (p *Pipe) writeLocked(ctx context.Context, ops writeOps) (int64, error) { wanted := ops.left() avail := p.max - p.view.Size() if wanted > avail { - if wanted <= p.atomicIOBytes { + if wanted <= atomicIOBytes { return 0, syserror.ErrWouldBlock } ops.limit(avail) diff --git a/pkg/sentry/kernel/pipe/pipe_test.go b/pkg/sentry/kernel/pipe/pipe_test.go index fe97e9800..3dd739080 100644 --- a/pkg/sentry/kernel/pipe/pipe_test.go +++ b/pkg/sentry/kernel/pipe/pipe_test.go @@ -26,7 +26,7 @@ import ( func TestPipeRW(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -46,7 +46,7 @@ func TestPipeRW(t *testing.T) { func TestPipeReadBlock(t *testing.T) { ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, 65536, 4096) + r, w := NewConnectedPipe(ctx, 65536) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -61,7 +61,7 @@ func TestPipeWriteBlock(t *testing.T) { const capacity = MinimumPipeSize ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, capacity, atomicIOBytes) + r, w := NewConnectedPipe(ctx, capacity) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -76,7 +76,7 @@ func TestPipeWriteUntilEnd(t *testing.T) { const atomicIOBytes = 2 ctx := contexttest.Context(t) - r, w := NewConnectedPipe(ctx, atomicIOBytes, atomicIOBytes) + r, w := NewConnectedPipe(ctx, atomicIOBytes) defer r.DecRef(ctx) defer w.DecRef(ctx) @@ -116,7 +116,8 @@ func TestPipeWriteUntilEnd(t *testing.T) { } } if err != nil { - t.Fatalf("Readv: got unexpected error %v", err) + t.Errorf("Readv: got unexpected error %v", err) + return } } }() diff --git a/pkg/sentry/kernel/pipe/vfs.go b/pkg/sentry/kernel/pipe/vfs.go index 1a152142b..7b23cbe86 100644 --- a/pkg/sentry/kernel/pipe/vfs.go +++ b/pkg/sentry/kernel/pipe/vfs.go @@ -33,6 +33,8 @@ import ( // VFSPipe represents the actual pipe, analagous to an inode. VFSPipes should // not be copied. +// +// +stateify savable type VFSPipe struct { // mu protects the fields below. mu sync.Mutex `state:"nosave"` @@ -52,9 +54,9 @@ type VFSPipe struct { } // NewVFSPipe returns an initialized VFSPipe. -func NewVFSPipe(isNamed bool, sizeBytes, atomicIOBytes int64) *VFSPipe { +func NewVFSPipe(isNamed bool, sizeBytes int64) *VFSPipe { var vp VFSPipe - initPipe(&vp.pipe, isNamed, sizeBytes, atomicIOBytes) + initPipe(&vp.pipe, isNamed, sizeBytes) return &vp } @@ -164,6 +166,8 @@ func (vp *VFSPipe) newFD(mnt *vfs.Mount, vfsd *vfs.Dentry, statusFlags uint32, l // VFSPipeFD implements vfs.FileDescriptionImpl for pipes. It also implements // non-atomic usermem.IO methods, allowing it to be passed as usermem.IO to // other FileDescriptions for splice(2) and tee(2). +// +// +stateify savable type VFSPipeFD struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/kernel/ptrace.go b/pkg/sentry/kernel/ptrace.go index 1145faf13..1abfe2201 100644 --- a/pkg/sentry/kernel/ptrace.go +++ b/pkg/sentry/kernel/ptrace.go @@ -1000,7 +1000,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { // at the address specified by the data parameter, and the return value // is the error flag." - ptrace(2) word := t.Arch().Native(0) - if _, err := word.CopyIn(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { + if _, err := word.CopyIn(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr); err != nil { return err } _, err := word.CopyOut(t, data) @@ -1008,7 +1008,7 @@ func (t *Task) Ptrace(req int64, pid ThreadID, addr, data usermem.Addr) error { case linux.PTRACE_POKETEXT, linux.PTRACE_POKEDATA: word := t.Arch().Native(uintptr(data)) - _, err := word.CopyOut(target.AsCopyContext(usermem.IOOpts{IgnorePermissions: true}), addr) + _, err := word.CopyOut(target.CopyContext(t, usermem.IOOpts{IgnorePermissions: true}), addr) return err case linux.PTRACE_GETREGSET: diff --git a/pkg/sentry/kernel/semaphore/semaphore.go b/pkg/sentry/kernel/semaphore/semaphore.go index c00fa1138..b99c0bffa 100644 --- a/pkg/sentry/kernel/semaphore/semaphore.go +++ b/pkg/sentry/kernel/semaphore/semaphore.go @@ -103,6 +103,7 @@ type waiter struct { waiterEntry // value represents how much resource the waiter needs to wake up. + // The value is either 0 or negative. value int16 ch chan struct{} } @@ -283,6 +284,33 @@ func (s *Set) Change(ctx context.Context, creds *auth.Credentials, owner fs.File return nil } +// GetStat extracts semid_ds information from the set. +func (s *Set) GetStat(creds *auth.Credentials) (*linux.SemidDS, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // "The calling process must have read permission on the semaphore set." + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return nil, syserror.EACCES + } + + ds := &linux.SemidDS{ + SemPerm: linux.IPCPerm{ + Key: uint32(s.key), + UID: uint32(creds.UserNamespace.MapFromKUID(s.owner.UID)), + GID: uint32(creds.UserNamespace.MapFromKGID(s.owner.GID)), + CUID: uint32(creds.UserNamespace.MapFromKUID(s.creator.UID)), + CGID: uint32(creds.UserNamespace.MapFromKGID(s.creator.GID)), + Mode: uint16(s.perms.LinuxMode()), + Seq: 0, // IPC sequence not supported. + }, + SemOTime: s.opTime.TimeT(), + SemCTime: s.changeTime.TimeT(), + SemNSems: uint64(s.Size()), + } + return ds, nil +} + // SetVal overrides a semaphore value, waking up waiters as needed. func (s *Set) SetVal(ctx context.Context, num int32, val int16, creds *auth.Credentials, pid int32) error { if val < 0 || val > valueMax { @@ -320,7 +348,7 @@ func (s *Set) SetValAll(ctx context.Context, vals []uint16, creds *auth.Credenti } for _, val := range vals { - if val < 0 || val > valueMax { + if val > valueMax { return syserror.ERANGE } } @@ -396,6 +424,42 @@ func (s *Set) GetPID(num int32, creds *auth.Credentials) (int32, error) { return sem.pid, nil } +func (s *Set) countWaiters(num int32, creds *auth.Credentials, pred func(w *waiter) bool) (uint16, error) { + s.mu.Lock() + defer s.mu.Unlock() + + // The calling process must have read permission on the semaphore set. + if !s.checkPerms(creds, fs.PermMask{Read: true}) { + return 0, syserror.EACCES + } + + sem := s.findSem(num) + if sem == nil { + return 0, syserror.ERANGE + } + var cnt uint16 + for w := sem.waiters.Front(); w != nil; w = w.Next() { + if pred(w) { + cnt++ + } + } + return cnt, nil +} + +// CountZeroWaiters returns number of waiters waiting for the sem's value to increase. +func (s *Set) CountZeroWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value == 0 + }) +} + +// CountNegativeWaiters returns number of waiters waiting for the sem to go to zero. +func (s *Set) CountNegativeWaiters(num int32, creds *auth.Credentials) (uint16, error) { + return s.countWaiters(num, creds, func(w *waiter) bool { + return w.value < 0 + }) +} + // ExecuteOps attempts to execute a list of operations to the set. It only // succeeds when all operations can be applied. No changes are made if it fails. // @@ -548,11 +612,18 @@ func (s *Set) destroy() { } } +func abs(val int16) int16 { + if val < 0 { + return -val + } + return val +} + // wakeWaiters goes over all waiters and checks which of them can be notified. func (s *sem) wakeWaiters() { // Note that this will release all waiters waiting for 0 too. for w := s.waiters.Front(); w != nil; { - if s.value < w.value { + if s.value < abs(w.value) { // Still blocked, skip it. w = w.Next() continue diff --git a/pkg/sentry/kernel/sessions.go b/pkg/sentry/kernel/sessions.go index df5c8421b..5bddb0a36 100644 --- a/pkg/sentry/kernel/sessions.go +++ b/pkg/sentry/kernel/sessions.go @@ -477,20 +477,20 @@ func (tg *ThreadGroup) Session() *Session { // // If this group isn't visible in this namespace, zero will be returned. It is // the callers responsibility to check that before using this function. -func (pidns *PIDNamespace) IDOfSession(s *Session) SessionID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sids[s] +func (ns *PIDNamespace) IDOfSession(s *Session) SessionID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sids[s] } // SessionWithID returns the Session with the given ID in the PID namespace ns, // or nil if that given ID is not defined in this namespace. // // A reference is not taken on the session. -func (pidns *PIDNamespace) SessionWithID(id SessionID) *Session { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.sessions[id] +func (ns *PIDNamespace) SessionWithID(id SessionID) *Session { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.sessions[id] } // ProcessGroup returns the ThreadGroup's ProcessGroup. @@ -505,18 +505,18 @@ func (tg *ThreadGroup) ProcessGroup() *ProcessGroup { // IDOfProcessGroup returns the process group assigned to pg in PID namespace ns. // // The same constraints apply as IDOfSession. -func (pidns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.pgids[pg] +func (ns *PIDNamespace) IDOfProcessGroup(pg *ProcessGroup) ProcessGroupID { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.pgids[pg] } // ProcessGroupWithID returns the ProcessGroup with the given ID in the PID // namespace ns, or nil if that given ID is not defined in this namespace. // // A reference is not taken on the process group. -func (pidns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { - pidns.owner.mu.RLock() - defer pidns.owner.mu.RUnlock() - return pidns.processGroups[id] +func (ns *PIDNamespace) ProcessGroupWithID(id ProcessGroupID) *ProcessGroup { + ns.owner.mu.RLock() + defer ns.owner.mu.RUnlock() + return ns.processGroups[id] } diff --git a/pkg/sentry/kernel/shm/BUILD b/pkg/sentry/kernel/shm/BUILD index f8a382fd8..80a592c8f 100644 --- a/pkg/sentry/kernel/shm/BUILD +++ b/pkg/sentry/kernel/shm/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "shm_refs.go", package = "shm", prefix = "Shm", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Shm", }, @@ -27,7 +27,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", - "//pkg/refs_vfs2", + "//pkg/refsvfs2", "//pkg/sentry/device", "//pkg/sentry/fs", "//pkg/sentry/kernel/auth", diff --git a/pkg/sentry/kernel/task_clone.go b/pkg/sentry/kernel/task_clone.go index 682080c14..527344162 100644 --- a/pkg/sentry/kernel/task_clone.go +++ b/pkg/sentry/kernel/task_clone.go @@ -355,7 +355,7 @@ func (t *Task) Clone(opts *CloneOptions) (ThreadID, *SyscallControl, error) { } if opts.ChildSetTID { ctid := nt.ThreadID() - ctid.CopyOut(nt.AsCopyContext(usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) + ctid.CopyOut(nt.CopyContext(t, usermem.IOOpts{AddressSpaceActive: false}), opts.ChildTID) } ntid := t.tg.pidns.IDOfTask(nt) if opts.ParentSetTID { diff --git a/pkg/sentry/kernel/task_usermem.go b/pkg/sentry/kernel/task_usermem.go index ce134bf54..94dabbcd8 100644 --- a/pkg/sentry/kernel/task_usermem.go +++ b/pkg/sentry/kernel/task_usermem.go @@ -18,7 +18,8 @@ import ( "math" "gvisor.dev/gvisor/pkg/abi/linux" - "gvisor.dev/gvisor/pkg/marshal" + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/sentry/mm" "gvisor.dev/gvisor/pkg/syserror" "gvisor.dev/gvisor/pkg/usermem" ) @@ -281,29 +282,89 @@ func (t *Task) IovecsIOSequence(addr usermem.Addr, iovcnt int, opts usermem.IOOp }, nil } -// copyContext implements marshal.CopyContext. It wraps a task to allow copying -// memory to and from the task memory with custom usermem.IOOpts. -type copyContext struct { - *Task +type taskCopyContext struct { + ctx context.Context + t *Task opts usermem.IOOpts } -// AsCopyContext wraps the task and returns it as CopyContext. -func (t *Task) AsCopyContext(opts usermem.IOOpts) marshal.CopyContext { - return ©Context{t, opts} +// CopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. +func (t *Task) CopyContext(ctx context.Context, opts usermem.IOOpts) *taskCopyContext { + return &taskCopyContext{ + ctx: ctx, + t: t, + opts: opts, + } +} + +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (cc *taskCopyContext) CopyScratchBuffer(size int) []byte { + if ctxTask, ok := cc.ctx.(*Task); ok { + return ctxTask.CopyScratchBuffer(size) + } + return make([]byte, size) +} + +func (cc *taskCopyContext) getMemoryManager() (*mm.MemoryManager, error) { + cc.t.mu.Lock() + tmm := cc.t.MemoryManager() + cc.t.mu.Unlock() + if !tmm.IncUsers() { + return nil, syserror.EFAULT + } + return tmm, nil +} + +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *taskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyIn(cc.ctx, addr, dst, cc.opts) +} + +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *taskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + tmm, err := cc.getMemoryManager() + if err != nil { + return 0, err + } + defer tmm.DecUsers(cc.ctx) + return tmm.CopyOut(cc.ctx, addr, src, cc.opts) +} + +type ownTaskCopyContext struct { + t *Task + opts usermem.IOOpts +} + +// OwnCopyContext returns a marshal.CopyContext that copies to/from t's address +// space using opts. The returned CopyContext may only be used by t's task +// goroutine. +// +// Since t already implements marshal.CopyContext, this is only needed to +// override the usermem.IOOpts used for the copy. +func (t *Task) OwnCopyContext(opts usermem.IOOpts) *ownTaskCopyContext { + return &ownTaskCopyContext{ + t: t, + opts: opts, + } } -// CopyInString copies a string in from the task's memory. -func (t *copyContext) CopyInString(addr usermem.Addr, maxLen int) (string, error) { - return usermem.CopyStringIn(t, t.MemoryManager(), addr, maxLen, t.opts) +// CopyScratchBuffer implements marshal.CopyContext.CopyScratchBuffer. +func (cc *ownTaskCopyContext) CopyScratchBuffer(size int) []byte { + return cc.t.CopyScratchBuffer(size) } -// CopyInBytes copies task memory into dst from an IO context. -func (t *copyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { - return t.MemoryManager().CopyIn(t, addr, dst, t.opts) +// CopyInBytes implements marshal.CopyContext.CopyInBytes. +func (cc *ownTaskCopyContext) CopyInBytes(addr usermem.Addr, dst []byte) (int, error) { + return cc.t.MemoryManager().CopyIn(cc.t, addr, dst, cc.opts) } -// CopyOutBytes copies src into task memoryfrom an IO context. -func (t *copyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { - return t.MemoryManager().CopyOut(t, addr, src, t.opts) +// CopyOutBytes implements marshal.CopyContext.CopyOutBytes. +func (cc *ownTaskCopyContext) CopyOutBytes(addr usermem.Addr, src []byte) (int, error) { + return cc.t.MemoryManager().CopyOut(cc.t, addr, src, cc.opts) } diff --git a/pkg/sentry/kernel/vdso.go b/pkg/sentry/kernel/vdso.go index 9bc452e67..9e5c2d26f 100644 --- a/pkg/sentry/kernel/vdso.go +++ b/pkg/sentry/kernel/vdso.go @@ -115,7 +115,7 @@ func (v *VDSOParamPage) incrementSeq(paramPage safemem.Block) error { } if old != v.seq { - return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d. Application may hang or get incorrect time from the VDSO.", old, v.seq) + return fmt.Errorf("unexpected VDSOParamPage seq value: got %d expected %d; application may hang or get incorrect time from the VDSO", old, v.seq) } v.seq = next diff --git a/pkg/sentry/mm/BUILD b/pkg/sentry/mm/BUILD index b4a47ccca..6dbeccfe2 100644 --- a/pkg/sentry/mm/BUILD +++ b/pkg/sentry/mm/BUILD @@ -78,7 +78,7 @@ go_template_instance( out = "aio_mappable_refs.go", package = "mm", prefix = "aioMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "aioMappable", }, @@ -89,7 +89,7 @@ go_template_instance( out = "special_mappable_refs.go", package = "mm", prefix = "SpecialMappable", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SpecialMappable", }, @@ -127,6 +127,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safecopy", "//pkg/safemem", "//pkg/sentry/arch", diff --git a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go index 0a54dd30d..acad4c793 100644 --- a/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_amd64_unsafe.go @@ -79,6 +79,18 @@ func bluepillStopGuest(c *vCPU) { c.runData.requestInterruptWindow = 0 } +// bluepillSigBus is reponsible for injecting NMI to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_NMI, 0); errno != 0 { + throw("NMI injection failed") + } +} + // bluepillReadyStopGuest checks whether the current vCPU is ready for interrupt injection. // //go:nosplit diff --git a/pkg/sentry/platform/kvm/bluepill_arm64.go b/pkg/sentry/platform/kvm/bluepill_arm64.go index 58f3d6fdd..965ad66b5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64.go @@ -27,15 +27,20 @@ var ( // The action for bluepillSignal is changed by sigaction(). bluepillSignal = syscall.SIGILL - // vcpuSErr is the event of system error. - vcpuSErr = kvmVcpuEvents{ + // vcpuSErrBounce is the event of system error for bouncing KVM. + vcpuSErrBounce = kvmVcpuEvents{ exception: exception{ sErrPending: 1, - sErrHasEsr: 0, - pad: [6]uint8{0, 0, 0, 0, 0, 0}, - sErrEsr: 1, }, - rsvd: [12]uint32{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + } + + // vcpuSErrNMI is the event of system error to trigger sigbus. + vcpuSErrNMI = kvmVcpuEvents{ + exception: exception{ + sErrPending: 1, + sErrHasEsr: 1, + sErrEsr: _ESR_ELx_SERR_NMI, + }, } ) diff --git a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go index b35c930e2..9433d4da5 100644 --- a/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_arm64_unsafe.go @@ -80,11 +80,24 @@ func getHypercallID(addr uintptr) int { // //go:nosplit func bluepillStopGuest(c *vCPU) { - if _, _, errno := syscall.RawSyscall( + if _, _, errno := syscall.RawSyscall( // escapes: no. syscall.SYS_IOCTL, uintptr(c.fd), _KVM_SET_VCPU_EVENTS, - uintptr(unsafe.Pointer(&vcpuSErr))); errno != 0 { + uintptr(unsafe.Pointer(&vcpuSErrBounce))); errno != 0 { + throw("sErr injection failed") + } +} + +// bluepillSigBus is reponsible for injecting sError to trigger sigbus. +// +//go:nosplit +func bluepillSigBus(c *vCPU) { + if _, _, errno := syscall.RawSyscall( // escapes: no. + syscall.SYS_IOCTL, + uintptr(c.fd), + _KVM_SET_VCPU_EVENTS, + uintptr(unsafe.Pointer(&vcpuSErrNMI))); errno != 0 { throw("sErr injection failed") } } diff --git a/pkg/sentry/platform/kvm/bluepill_unsafe.go b/pkg/sentry/platform/kvm/bluepill_unsafe.go index eb05950cd..75085ac6a 100644 --- a/pkg/sentry/platform/kvm/bluepill_unsafe.go +++ b/pkg/sentry/platform/kvm/bluepill_unsafe.go @@ -146,12 +146,7 @@ func bluepillHandler(context unsafe.Pointer) { // MMIO exit we receive EFAULT from the run ioctl. We // always inject an NMI here since we may be in kernel // mode and have interrupts disabled. - if _, _, errno := syscall.RawSyscall( // escapes: no. - syscall.SYS_IOCTL, - uintptr(c.fd), - _KVM_NMI, 0); errno != 0 { - throw("NMI injection failed") - } + bluepillSigBus(c) continue // Rerun vCPU. default: throw("run failed") diff --git a/pkg/sentry/platform/kvm/kvm.go b/pkg/sentry/platform/kvm/kvm.go index dd45ad10b..5979aef97 100644 --- a/pkg/sentry/platform/kvm/kvm.go +++ b/pkg/sentry/platform/kvm/kvm.go @@ -158,8 +158,7 @@ func (*KVM) MaxUserAddress() usermem.Addr { // NewAddressSpace returns a new pagetable root. func (k *KVM) NewAddressSpace(_ interface{}) (platform.AddressSpace, <-chan struct{}, error) { // Allocate page tables and install system mappings. - pageTables := pagetables.New(newAllocator()) - k.machine.mapUpperHalf(pageTables) + pageTables := pagetables.NewWithUpper(newAllocator(), k.machine.upperSharedPageTables, ring0.KernelStartAddress) // Return the new address space. return &addressSpace{ diff --git a/pkg/sentry/platform/kvm/kvm_const_arm64.go b/pkg/sentry/platform/kvm/kvm_const_arm64.go index 84df0f878..b060d9544 100644 --- a/pkg/sentry/platform/kvm/kvm_const_arm64.go +++ b/pkg/sentry/platform/kvm/kvm_const_arm64.go @@ -38,6 +38,8 @@ const ( _KVM_ARM64_REGS_SCTLR_EL1 = 0x603000000013c080 _KVM_ARM64_REGS_CPACR_EL1 = 0x603000000013c082 _KVM_ARM64_REGS_VBAR_EL1 = 0x603000000013c600 + _KVM_ARM64_REGS_TIMER_CNT = 0x603000000013df1a + _KVM_ARM64_REGS_CNTFRQ_EL0 = 0x603000000013df00 ) // Arm64: Architectural Feature Access Control Register EL1. @@ -149,6 +151,9 @@ const ( _ESR_SEGV_PEMERR_L1 = 0xd _ESR_SEGV_PEMERR_L2 = 0xe _ESR_SEGV_PEMERR_L3 = 0xf + + // Custom ISS field definitions for system error. + _ESR_ELx_SERR_NMI = 0x1 ) // Arm64: MMIO base address used to dispatch hypercalls. diff --git a/pkg/sentry/platform/kvm/machine.go b/pkg/sentry/platform/kvm/machine.go index 61ed24d01..e2fffc99b 100644 --- a/pkg/sentry/platform/kvm/machine.go +++ b/pkg/sentry/platform/kvm/machine.go @@ -25,6 +25,7 @@ import ( "gvisor.dev/gvisor/pkg/procid" "gvisor.dev/gvisor/pkg/sentry/platform/ring0" "gvisor.dev/gvisor/pkg/sentry/platform/ring0/pagetables" + ktime "gvisor.dev/gvisor/pkg/sentry/time" "gvisor.dev/gvisor/pkg/sync" "gvisor.dev/gvisor/pkg/usermem" ) @@ -40,6 +41,9 @@ type machine struct { // slots are currently being updated, and the caller should retry. nextSlot uint32 + // upperSharedPageTables tracks the read-only shared upper of all the pagetables. + upperSharedPageTables *pagetables.PageTables + // kernel is the set of global structures. kernel ring0.Kernel @@ -198,9 +202,7 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of vCPUs is %d.", m.maxVCPUs) m.vCPUsByTID = make(map[uint64]*vCPU) m.vCPUsByID = make([]*vCPU, m.maxVCPUs) - m.kernel.Init(ring0.KernelOpts{ - PageTables: pagetables.New(newAllocator()), - }, m.maxVCPUs) + m.kernel.Init(m.maxVCPUs) // Pull the maximum slots. maxSlots, _, errno := syscall.RawSyscall(syscall.SYS_IOCTL, uintptr(m.fd), _KVM_CHECK_EXTENSION, _KVM_CAP_MAX_MEMSLOTS) @@ -212,6 +214,13 @@ func newMachine(vm int) (*machine, error) { log.Debugf("The maximum number of slots is %d.", m.maxSlots) m.usedSlots = make([]uintptr, m.maxSlots) + // Create the upper shared pagetables and kernel(sentry) pagetables. + m.upperSharedPageTables = pagetables.New(newAllocator()) + m.mapUpperHalf(m.upperSharedPageTables) + m.upperSharedPageTables.Allocator.(*allocator).base.Drain() + m.upperSharedPageTables.MarkReadOnlyShared() + m.kernel.PageTables = pagetables.NewWithUpper(newAllocator(), m.upperSharedPageTables, ring0.KernelStartAddress) + // Apply the physical mappings. Note that these mappings may point to // guest physical addresses that are not actually available. These // physical pages are mapped on demand, see kernel_unsafe.go. @@ -225,7 +234,6 @@ func newMachine(vm int) (*machine, error) { return true // Keep iterating. }) - m.mapUpperHalf(m.kernel.PageTables) var physicalRegionsReadOnly []physicalRegion var physicalRegionsAvailable []physicalRegion @@ -625,3 +633,35 @@ func (c *vCPU) BounceToKernel() { func (c *vCPU) BounceToHost() { c.bounce(true) } + +// setSystemTimeLegacy calibrates and sets an approximate system time. +func (c *vCPU) setSystemTimeLegacy() error { + const minIterations = 10 + minimum := uint64(0) + for iter := 0; ; iter++ { + // Try to set the TSC to an estimate of where it will be + // on the host during a "fast" system call iteration. + start := uint64(ktime.Rdtsc()) + if err := c.setTSC(start + (minimum / 2)); err != nil { + return err + } + // See if this is our new minimum call time. Note that this + // serves two functions: one, we make sure that we are + // accurately predicting the offset we need to set. Second, we + // don't want to do the final set on a slow call, which could + // produce a really bad result. + end := uint64(ktime.Rdtsc()) + if end < start { + continue // Totally bogus: unstable TSC? + } + current := end - start + if current < minimum || iter == 0 { + minimum = current // Set our new minimum. + } + // Is this past minIterations and within ~10% of minimum? + upperThreshold := (((minimum << 3) + minimum) >> 3) + if iter >= minIterations && current <= upperThreshold { + return nil + } + } +} diff --git a/pkg/sentry/platform/kvm/machine_amd64.go b/pkg/sentry/platform/kvm/machine_amd64.go index c67127d95..8e03c310d 100644 --- a/pkg/sentry/platform/kvm/machine_amd64.go +++ b/pkg/sentry/platform/kvm/machine_amd64.go @@ -252,38 +252,6 @@ func (c *vCPU) setSystemTime() error { } } -// setSystemTimeLegacy calibrates and sets an approximate system time. -func (c *vCPU) setSystemTimeLegacy() error { - const minIterations = 10 - minimum := uint64(0) - for iter := 0; ; iter++ { - // Try to set the TSC to an estimate of where it will be - // on the host during a "fast" system call iteration. - start := uint64(ktime.Rdtsc()) - if err := c.setTSC(start + (minimum / 2)); err != nil { - return err - } - // See if this is our new minimum call time. Note that this - // serves two functions: one, we make sure that we are - // accurately predicting the offset we need to set. Second, we - // don't want to do the final set on a slow call, which could - // produce a really bad result. - end := uint64(ktime.Rdtsc()) - if end < start { - continue // Totally bogus: unstable TSC? - } - current := end - start - if current < minimum || iter == 0 { - minimum = current // Set our new minimum. - } - // Is this past minIterations and within ~10% of minimum? - upperThreshold := (((minimum << 3) + minimum) >> 3) - if iter >= minIterations && current <= upperThreshold { - return nil - } - } -} - // nonCanonical generates a canonical address return. // //go:nosplit @@ -464,30 +432,27 @@ func availableRegionsForSetMem() (phyRegions []physicalRegion) { return physicalRegions } -var execRegions = func() (regions []region) { +func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { + // Map all the executible regions so that all the entry functions + // are mapped in the upper half. applyVirtualRegions(func(vr virtualRegion) { if excludeVirtualRegion(vr) || vr.filename == "[vsyscall]" { return } + if vr.accessType.Execute { - regions = append(regions, vr.region) + r := vr.region + physical, length, ok := translateToPhysical(r.virtual) + if !ok || length < r.length { + panic("impossible translation") + } + pageTable.Map( + usermem.Addr(ring0.KernelStartAddress|r.virtual), + r.length, + pagetables.MapOpts{AccessType: usermem.Execute}, + physical) } }) - return -}() - -func (m *machine) mapUpperHalf(pageTable *pagetables.PageTables) { - for _, r := range execRegions { - physical, length, ok := translateToPhysical(r.virtual) - if !ok || length < r.length { - panic("impossilbe translation") - } - pageTable.Map( - usermem.Addr(ring0.KernelStartAddress|r.virtual), - r.length, - pagetables.MapOpts{AccessType: usermem.Execute}, - physical) - } for start, end := range m.kernel.EntryRegions() { regionLen := end - start physical, length, ok := translateToPhysical(start) diff --git a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go index a163f956d..fd92c3873 100644 --- a/pkg/sentry/platform/kvm/machine_arm64_unsafe.go +++ b/pkg/sentry/platform/kvm/machine_arm64_unsafe.go @@ -159,9 +159,33 @@ func (c *vCPU) initArchState() error { } c.floatingPointState = arch.NewFloatingPointData() + + return c.setSystemTime() +} + +// setTSC sets the counter Virtual Offset. +func (c *vCPU) setTSC(value uint64) error { + var ( + reg kvmOneReg + data uint64 + ) + + reg.addr = uint64(reflect.ValueOf(&data).Pointer()) + reg.id = _KVM_ARM64_REGS_TIMER_CNT + data = uint64(value) + + if err := c.setOneRegister(®); err != nil { + return err + } + return nil } +// setSystemTime sets the vCPU to the system time. +func (c *vCPU) setSystemTime() error { + return c.setSystemTimeLegacy() +} + //go:nosplit func (c *vCPU) loadSegments(tid uint64) { // TODO(gvisor.dev/issue/1238): TLS is not supported. @@ -197,7 +221,7 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) if regs := switchOpts.Registers; !ring0.IsCanonical(regs.Pc) { return nonCanonical(regs.Pc, int32(syscall.SIGSEGV), info) } else if !ring0.IsCanonical(regs.Sp) { - return nonCanonical(regs.Sp, int32(syscall.SIGBUS), info) + return nonCanonical(regs.Sp, int32(syscall.SIGSEGV), info) } // Assign PCIDs. @@ -233,10 +257,13 @@ func (c *vCPU) SwitchToUser(switchOpts ring0.SwitchOpts, info *arch.SignalInfo) case ring0.PageFault: return c.fault(int32(syscall.SIGSEGV), info) + case ring0.El0ErrNMI: + return c.fault(int32(syscall.SIGBUS), info) case ring0.Vector(bounce): // ring0.VirtualizationException return usermem.NoAccess, platform.ErrContextInterrupt - case ring0.El0Sync_undef, - ring0.El1Sync_undef: + case ring0.El0SyncUndef: + return c.fault(int32(syscall.SIGILL), info) + case ring0.El1SyncUndef: *info = arch.SignalInfo{ Signo: int32(syscall.SIGILL), Code: 1, // ILL_ILLOPC (illegal opcode). diff --git a/pkg/sentry/platform/ring0/aarch64.go b/pkg/sentry/platform/ring0/aarch64.go index 87a573cc4..327d48465 100644 --- a/pkg/sentry/platform/ring0/aarch64.go +++ b/pkg/sentry/platform/ring0/aarch64.go @@ -58,46 +58,55 @@ type Vector uintptr // Exception vectors. const ( - El1SyncInvalid = iota - El1IrqInvalid - El1FiqInvalid - El1ErrorInvalid + El1InvSync = iota + El1InvIrq + El1InvFiq + El1InvError + El1Sync El1Irq El1Fiq - El1Error + El1Err + El0Sync El0Irq El0Fiq - El0Error - El0Sync_invalid - El0Irq_invalid - El0Fiq_invalid - El0Error_invalid - El1Sync_da - El1Sync_ia - El1Sync_sp_pc - El1Sync_undef - El1Sync_dbg - El1Sync_inv - El0Sync_svc - El0Sync_da - El0Sync_ia - El0Sync_fpsimd_acc - El0Sync_sve_acc - El0Sync_sys - El0Sync_sp_pc - El0Sync_undef - El0Sync_dbg - El0Sync_inv + El0Err + + El0InvSync + El0InvIrq + El0InvFiq + El0InvErr + + El1SyncDa + El1SyncIa + El1SyncSpPc + El1SyncUndef + El1SyncDbg + El1SyncInv + + El0SyncSVC + El0SyncDa + El0SyncIa + El0SyncFpsimdAcc + El0SyncSveAcc + El0SyncSys + El0SyncSpPc + El0SyncUndef + El0SyncDbg + El0SyncInv + + El0ErrNMI + El0ErrBounce + _NR_INTERRUPTS ) // System call vectors. const ( - Syscall Vector = El0Sync_svc - PageFault Vector = El0Sync_da - VirtualizationException Vector = El0Error + Syscall Vector = El0SyncSVC + PageFault Vector = El0SyncDa + VirtualizationException Vector = El0ErrBounce ) // VirtualAddressBits returns the number bits available for virtual addresses. diff --git a/pkg/sentry/platform/ring0/defs.go b/pkg/sentry/platform/ring0/defs.go index e6daf24df..f9765771e 100644 --- a/pkg/sentry/platform/ring0/defs.go +++ b/pkg/sentry/platform/ring0/defs.go @@ -23,6 +23,9 @@ import ( // // This contains global state, shared by multiple CPUs. type Kernel struct { + // PageTables are the kernel pagetables; this must be provided. + PageTables *pagetables.PageTables + KernelArchState } diff --git a/pkg/sentry/platform/ring0/defs_amd64.go b/pkg/sentry/platform/ring0/defs_amd64.go index 00899273e..7a2275558 100644 --- a/pkg/sentry/platform/ring0/defs_amd64.go +++ b/pkg/sentry/platform/ring0/defs_amd64.go @@ -66,17 +66,9 @@ var ( KernelDataSegment SegmentDescriptor ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts - - // cpuEntries is array of kernelEntry for all cpus + // cpuEntries is array of kernelEntry for all cpus. cpuEntries []kernelEntry // globalIDT is our set of interrupt gates. diff --git a/pkg/sentry/platform/ring0/defs_arm64.go b/pkg/sentry/platform/ring0/defs_arm64.go index 508236e46..a014dcbc0 100644 --- a/pkg/sentry/platform/ring0/defs_arm64.go +++ b/pkg/sentry/platform/ring0/defs_arm64.go @@ -32,15 +32,8 @@ var ( KernelStartAddress = ^uintptr(0) - (UserspaceSize - 1) ) -// KernelOpts has initialization options for the kernel. -type KernelOpts struct { - // PageTables are the kernel pagetables; this must be provided. - PageTables *pagetables.PageTables -} - // KernelArchState contains architecture-specific state. type KernelArchState struct { - KernelOpts } // CPUArchState contains CPU-specific arch state. diff --git a/pkg/sentry/platform/ring0/entry_arm64.s b/pkg/sentry/platform/ring0/entry_arm64.s index 2370a9276..f489ad352 100644 --- a/pkg/sentry/platform/ring0/entry_arm64.s +++ b/pkg/sentry/platform/ring0/entry_arm64.s @@ -288,6 +288,10 @@ #define ESR_ELx_WFx_ISS_WFE (UL(1) << 0) #define ESR_ELx_xVC_IMM_MASK ((1UL << 16) - 1) +/* ISS field definitions for system error */ +#define ESR_ELx_SERR_MASK (0x1) +#define ESR_ELx_SERR_NMI (0x1) + // LOAD_KERNEL_ADDRESS loads a kernel address. #define LOAD_KERNEL_ADDRESS(from, to) \ MOVD from, to; \ @@ -366,6 +370,19 @@ MOVD R4, CPU_REGISTERS+PTRACE_SP(RSV_REG); \ LOAD_KERNEL_STACK(RSV_REG); // Load the temporary stack. +// EXCEPTION_WITH_ERROR is a common exception handler function. +#define EXCEPTION_WITH_ERROR(user, vector) \ + WORD $0xd538d092; \ //MRS TPIDR_EL1, R18 + WORD $0xd538601a; \ //MRS FAR_EL1, R26 + MOVD R26, CPU_FAULT_ADDR(RSV_REG); \ + MOVD $user, R3; \ + MOVD R3, CPU_ERROR_TYPE(RSV_REG); \ // Set error type to user. + MOVD $vector, R3; \ + MOVD R3, CPU_VECTOR_CODE(RSV_REG); \ + MRS ESR_EL1, R3; \ + MOVD R3, CPU_ERROR_CODE(RSV_REG); \ + B ·kernelExitToEl1(SB); + // storeAppASID writes the application's asid value. TEXT ·storeAppASID(SB),NOSPLIT,$0-8 MOVD asid+0(FP), R1 @@ -503,6 +520,10 @@ TEXT ·kernelExitToEl1(SB),NOSPLIT,$0 MOVD CPU_REGISTERS+PTRACE_PC(RSV_REG), R1 MSR R1, ELR_EL1 + // restore sentry's tls. + MOVD CPU_REGISTERS+PTRACE_TLS(RSV_REG), R1 + MSR R1, TPIDR_EL0 + MOVD CPU_REGISTERS+PTRACE_SP(RSV_REG), R1 MOVD R1, RSP @@ -659,21 +680,7 @@ el0_svc: el0_da: el0_ia: - WORD $0xd538d092 //MRS TPIDR_EL1, R18 - WORD $0xd538601a //MRS FAR_EL1, R26 - - MOVD R26, CPU_FAULT_ADDR(RSV_REG) - - MOVD $1, R3 - MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. - - MOVD $PageFault, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - MRS ESR_EL1, R3 - MOVD R3, CPU_ERROR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, PageFault) el0_fpsimd_acc: B ·Shutdown(SB) @@ -688,10 +695,7 @@ el0_sp_pc: B ·Shutdown(SB) el0_undef: - MOVD $El0Sync_undef, R3 - MOVD R3, CPU_VECTOR_CODE(RSV_REG) - - B ·kernelExitToEl1(SB) + EXCEPTION_WITH_ERROR(1, El0SyncUndef) el0_dbg: B ·Shutdown(SB) @@ -707,6 +711,29 @@ TEXT ·El0_fiq(SB),NOSPLIT,$0 TEXT ·El0_error(SB),NOSPLIT,$0 KERNEL_ENTRY_FROM_EL0 + WORD $0xd5385219 // MRS ESR_EL1, R25 + AND $ESR_ELx_SERR_MASK, R25, R24 + CMP $ESR_ELx_SERR_NMI, R24 + BEQ el0_nmi + B el0_bounce +el0_nmi: + WORD $0xd538d092 //MRS TPIDR_EL1, R18 + WORD $0xd538601a //MRS FAR_EL1, R26 + + MOVD R26, CPU_FAULT_ADDR(RSV_REG) + + MOVD $1, R3 + MOVD R3, CPU_ERROR_TYPE(RSV_REG) // Set error type to user. + + MOVD $El0ErrNMI, R3 + MOVD R3, CPU_VECTOR_CODE(RSV_REG) + + MRS ESR_EL1, R3 + MOVD R3, CPU_ERROR_CODE(RSV_REG) + + B ·kernelExitToEl1(SB) + +el0_bounce: WORD $0xd538d092 //MRS TPIDR_EL1, R18 WORD $0xd538601a //MRS FAR_EL1, R26 @@ -718,7 +745,7 @@ TEXT ·El0_error(SB),NOSPLIT,$0 MOVD $VirtualizationException, R3 MOVD R3, CPU_VECTOR_CODE(RSV_REG) - B ·HaltAndResume(SB) + B ·kernelExitToEl1(SB) TEXT ·El0_sync_invalid(SB),NOSPLIT,$0 B ·Shutdown(SB) diff --git a/pkg/sentry/platform/ring0/kernel.go b/pkg/sentry/platform/ring0/kernel.go index 264be23d3..292f9d0cc 100644 --- a/pkg/sentry/platform/ring0/kernel.go +++ b/pkg/sentry/platform/ring0/kernel.go @@ -16,11 +16,9 @@ package ring0 // Init initializes a new kernel. // -// N.B. that constraints on KernelOpts must be satisfied. -// //go:nosplit -func (k *Kernel) Init(opts KernelOpts, maxCPUs int) { - k.init(opts, maxCPUs) +func (k *Kernel) Init(maxCPUs int) { + k.init(maxCPUs) } // Halt halts execution. diff --git a/pkg/sentry/platform/ring0/kernel_amd64.go b/pkg/sentry/platform/ring0/kernel_amd64.go index 3a9dff4cc..b55dc29b3 100644 --- a/pkg/sentry/platform/ring0/kernel_amd64.go +++ b/pkg/sentry/platform/ring0/kernel_amd64.go @@ -24,10 +24,7 @@ import ( ) // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables - +func (k *Kernel) init(maxCPUs int) { entrySize := reflect.TypeOf(kernelEntry{}).Size() var ( entries []kernelEntry diff --git a/pkg/sentry/platform/ring0/kernel_arm64.go b/pkg/sentry/platform/ring0/kernel_arm64.go index b294ccc7c..6cbbf001f 100644 --- a/pkg/sentry/platform/ring0/kernel_arm64.go +++ b/pkg/sentry/platform/ring0/kernel_arm64.go @@ -25,9 +25,7 @@ func HaltAndResume() func HaltEl1SvcAndResume() // init initializes architecture-specific state. -func (k *Kernel) init(opts KernelOpts, maxCPUs int) { - // Save the root page tables. - k.PageTables = opts.PageTables +func (k *Kernel) init(maxCPUs int) { } // init initializes architecture-specific state. diff --git a/pkg/sentry/platform/ring0/offsets_arm64.go b/pkg/sentry/platform/ring0/offsets_arm64.go index 45eba960d..53bc3353c 100644 --- a/pkg/sentry/platform/ring0/offsets_arm64.go +++ b/pkg/sentry/platform/ring0/offsets_arm64.go @@ -47,43 +47,36 @@ func Emit(w io.Writer) { fmt.Fprintf(w, "#define _KERNEL_FLAGS 0x%02x\n", KernelFlagsSet) fmt.Fprintf(w, "\n// Vectors.\n") - fmt.Fprintf(w, "#define El1SyncInvalid 0x%02x\n", El1SyncInvalid) - fmt.Fprintf(w, "#define El1IrqInvalid 0x%02x\n", El1IrqInvalid) - fmt.Fprintf(w, "#define El1FiqInvalid 0x%02x\n", El1FiqInvalid) - fmt.Fprintf(w, "#define El1ErrorInvalid 0x%02x\n", El1ErrorInvalid) fmt.Fprintf(w, "#define El1Sync 0x%02x\n", El1Sync) fmt.Fprintf(w, "#define El1Irq 0x%02x\n", El1Irq) fmt.Fprintf(w, "#define El1Fiq 0x%02x\n", El1Fiq) - fmt.Fprintf(w, "#define El1Error 0x%02x\n", El1Error) + fmt.Fprintf(w, "#define El1Err 0x%02x\n", El1Err) fmt.Fprintf(w, "#define El0Sync 0x%02x\n", El0Sync) fmt.Fprintf(w, "#define El0Irq 0x%02x\n", El0Irq) fmt.Fprintf(w, "#define El0Fiq 0x%02x\n", El0Fiq) - fmt.Fprintf(w, "#define El0Error 0x%02x\n", El0Error) + fmt.Fprintf(w, "#define El0Err 0x%02x\n", El0Err) - fmt.Fprintf(w, "#define El0Sync_invalid 0x%02x\n", El0Sync_invalid) - fmt.Fprintf(w, "#define El0Irq_invalid 0x%02x\n", El0Irq_invalid) - fmt.Fprintf(w, "#define El0Fiq_invalid 0x%02x\n", El0Fiq_invalid) - fmt.Fprintf(w, "#define El0Error_invalid 0x%02x\n", El0Error_invalid) + fmt.Fprintf(w, "#define El1SyncDa 0x%02x\n", El1SyncDa) + fmt.Fprintf(w, "#define El1SyncIa 0x%02x\n", El1SyncIa) + fmt.Fprintf(w, "#define El1SyncSpPc 0x%02x\n", El1SyncSpPc) + fmt.Fprintf(w, "#define El1SyncUndef 0x%02x\n", El1SyncUndef) + fmt.Fprintf(w, "#define El1SyncDbg 0x%02x\n", El1SyncDbg) + fmt.Fprintf(w, "#define El1SyncInv 0x%02x\n", El1SyncInv) - fmt.Fprintf(w, "#define El1Sync_da 0x%02x\n", El1Sync_da) - fmt.Fprintf(w, "#define El1Sync_ia 0x%02x\n", El1Sync_ia) - fmt.Fprintf(w, "#define El1Sync_sp_pc 0x%02x\n", El1Sync_sp_pc) - fmt.Fprintf(w, "#define El1Sync_undef 0x%02x\n", El1Sync_undef) - fmt.Fprintf(w, "#define El1Sync_dbg 0x%02x\n", El1Sync_dbg) - fmt.Fprintf(w, "#define El1Sync_inv 0x%02x\n", El1Sync_inv) + fmt.Fprintf(w, "#define El0SyncSVC 0x%02x\n", El0SyncSVC) + fmt.Fprintf(w, "#define El0SyncDa 0x%02x\n", El0SyncDa) + fmt.Fprintf(w, "#define El0SyncIa 0x%02x\n", El0SyncIa) + fmt.Fprintf(w, "#define El0SyncFpsimdAcc 0x%02x\n", El0SyncFpsimdAcc) + fmt.Fprintf(w, "#define El0SyncSveAcc 0x%02x\n", El0SyncSveAcc) + fmt.Fprintf(w, "#define El0SyncSys 0x%02x\n", El0SyncSys) + fmt.Fprintf(w, "#define El0SyncSpPc 0x%02x\n", El0SyncSpPc) + fmt.Fprintf(w, "#define El0SyncUndef 0x%02x\n", El0SyncUndef) + fmt.Fprintf(w, "#define El0SyncDbg 0x%02x\n", El0SyncDbg) + fmt.Fprintf(w, "#define El0SyncInv 0x%02x\n", El0SyncInv) - fmt.Fprintf(w, "#define El0Sync_svc 0x%02x\n", El0Sync_svc) - fmt.Fprintf(w, "#define El0Sync_da 0x%02x\n", El0Sync_da) - fmt.Fprintf(w, "#define El0Sync_ia 0x%02x\n", El0Sync_ia) - fmt.Fprintf(w, "#define El0Sync_fpsimd_acc 0x%02x\n", El0Sync_fpsimd_acc) - fmt.Fprintf(w, "#define El0Sync_sve_acc 0x%02x\n", El0Sync_sve_acc) - fmt.Fprintf(w, "#define El0Sync_sys 0x%02x\n", El0Sync_sys) - fmt.Fprintf(w, "#define El0Sync_sp_pc 0x%02x\n", El0Sync_sp_pc) - fmt.Fprintf(w, "#define El0Sync_undef 0x%02x\n", El0Sync_undef) - fmt.Fprintf(w, "#define El0Sync_dbg 0x%02x\n", El0Sync_dbg) - fmt.Fprintf(w, "#define El0Sync_inv 0x%02x\n", El0Sync_inv) + fmt.Fprintf(w, "#define El0ErrNMI 0x%02x\n", El0ErrNMI) fmt.Fprintf(w, "#define PageFault 0x%02x\n", PageFault) fmt.Fprintf(w, "#define Syscall 0x%02x\n", Syscall) diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables.go b/pkg/sentry/platform/ring0/pagetables/pagetables.go index 7f18ac296..bc16a1622 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables.go @@ -30,6 +30,10 @@ type PageTables struct { Allocator Allocator // root is the pagetable root. + // + // For same archs such as amd64, the upper of the PTEs is cloned + // from and owned by upperSharedPageTables which are shared among + // many PageTables if upperSharedPageTables is not nil. root *PTEs // rootPhysical is the cached physical address of the root. @@ -39,15 +43,52 @@ type PageTables struct { // archPageTables includes architecture-specific features. archPageTables + + // upperSharedPageTables represents a read-only shared upper + // of the Pagetable. When it is not nil, the upper is not + // allowed to be modified. + upperSharedPageTables *PageTables + + // upperStart is the start address of the upper portion that + // are shared from upperSharedPageTables + upperStart uintptr + + // readOnlyShared indicates the Pagetables are read-only and + // own the ranges that are shared with other Pagetables. + readOnlyShared bool } -// New returns new PageTables. -func New(a Allocator) *PageTables { +// NewWithUpper returns new PageTables. +// +// upperSharedPageTables are used for mapping the upper of addresses, +// starting at upperStart. These pageTables should not be touched (as +// invalidations may be incorrect) after they are passed as an +// upperSharedPageTables. Only when all dependent PageTables are gone +// may they be used. The intenteded use case is for kernel page tables, +// which are static and fixed. +// +// Precondition: upperStart must be between canonical ranges. +// Precondition: upperStart must be pgdSize aligned. +// precondition: upperSharedPageTables must be marked read-only shared. +func NewWithUpper(a Allocator, upperSharedPageTables *PageTables, upperStart uintptr) *PageTables { p := new(PageTables) p.Init(a) + if upperSharedPageTables != nil { + if !upperSharedPageTables.readOnlyShared { + panic("Only read-only shared pagetables can be used as upper") + } + p.upperSharedPageTables = upperSharedPageTables + p.upperStart = upperStart + p.cloneUpperShared() + } return p } +// New returns new PageTables. +func New(a Allocator) *PageTables { + return NewWithUpper(a, nil, 0) +} + // mapVisitor is used for map. type mapVisitor struct { target uintptr // Input. @@ -90,6 +131,21 @@ func (*mapVisitor) requiresSplit() bool { return true } // //go:nosplit func (p *PageTables) Map(addr usermem.Addr, length uintptr, opts MapOpts, physical uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // ignore change to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } if !opts.AccessType.Any() { return p.Unmap(addr, length) } @@ -128,12 +184,27 @@ func (v *unmapVisitor) visit(start uintptr, pte *PTE, align uintptr) { // // True is returned iff there was a previous mapping in the range. // -// Precondition: addr & length must be page-aligned. +// Precondition: addr & length must be page-aligned, their sum must not overflow. // // +checkescape:hard,stack // //go:nosplit func (p *PageTables) Unmap(addr usermem.Addr, length uintptr) bool { + if p.readOnlyShared { + panic("Should not modify read-only shared pagetables.") + } + if uintptr(addr)+length < uintptr(addr) { + panic("addr & length overflow") + } + if p.upperSharedPageTables != nil { + // ignore change to the read-only upper shared portion. + if uintptr(addr) >= p.upperStart { + return false + } + if uintptr(addr)+length > p.upperStart { + length = p.upperStart - uintptr(addr) + } + } w := unmapWalker{ pageTables: p, visitor: unmapVisitor{ @@ -218,3 +289,10 @@ func (p *PageTables) Lookup(addr usermem.Addr) (physical uintptr, opts MapOpts) w.iterateRange(uintptr(addr), uintptr(addr)+1) return w.visitor.physical + offset, w.visitor.opts } + +// MarkReadOnlyShared marks the pagetables read-only and can be shared. +// +// It is usually used on the pagetables that are used as the upper +func (p *PageTables) MarkReadOnlyShared() { + p.readOnlyShared = true +} diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go index 520161755..a4e416af7 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_aarch64.go @@ -24,14 +24,6 @@ import ( // archPageTables is architecture-specific data. type archPageTables struct { - // root is the pagetable root for kernel space. - root *PTEs - - // rootPhysical is the cached physical address of the root. - // - // This is saved only to prevent constant translation. - rootPhysical uintptr - asid uint16 } @@ -46,7 +38,7 @@ func (p *PageTables) TTBR0_EL1(noFlush bool, asid uint16) uint64 { // //go:nosplit func (p *PageTables) TTBR1_EL1(noFlush bool, asid uint16) uint64 { - return uint64(p.archPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset + return uint64(p.upperSharedPageTables.rootPhysical) | (uint64(asid)&ttbrASIDMask)<<ttbrASIDOffset } // Bits in page table entries. diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go index 0c153cf8c..e7ab887e5 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_amd64.go @@ -50,5 +50,26 @@ func (p *PageTables) Init(allocator Allocator) { p.rootPhysical = p.Allocator.PhysicalFor(p.root) } +func pgdIndex(upperStart uintptr) uintptr { + if upperStart&(pgdSize-1) != 0 { + panic("upperStart should be pgd size aligned") + } + if upperStart >= upperBottom { + return entriesPerPage/2 + (upperStart-upperBottom)/pgdSize + } + if upperStart < lowerTop { + return upperStart / pgdSize + } + panic("upperStart should be in canonical range") +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + start := pgdIndex(p.upperStart) + copy(p.root[start:entriesPerPage], p.upperSharedPageTables.root[start:entriesPerPage]) +} + // PTEs is a collection of entries. type PTEs [entriesPerPage]PTE diff --git a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go index 1a49f12a2..5392bf27a 100644 --- a/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/pagetables_arm64.go @@ -36,7 +36,7 @@ const ( pudSize = 1 << pudShift pgdSize = 1 << pgdShift - ttbrASIDOffset = 55 + ttbrASIDOffset = 48 ttbrASIDMask = 0xff entriesPerPage = 512 @@ -49,8 +49,17 @@ func (p *PageTables) Init(allocator Allocator) { p.Allocator = allocator p.root = p.Allocator.NewPTEs() p.rootPhysical = p.Allocator.PhysicalFor(p.root) - p.archPageTables.root = p.Allocator.NewPTEs() - p.archPageTables.rootPhysical = p.Allocator.PhysicalFor(p.archPageTables.root) +} + +// cloneUpperShared clone the upper from the upper shared page tables. +// +//go:nosplit +func (p *PageTables) cloneUpperShared() { + if p.upperStart != upperBottom { + panic("upperStart should be the same as upperBottom") + } + + // nothing to do for arm. } // PTEs is a collection of entries. diff --git a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go index c261d393a..157c9a7cc 100644 --- a/pkg/sentry/platform/ring0/pagetables/walker_arm64.go +++ b/pkg/sentry/platform/ring0/pagetables/walker_arm64.go @@ -116,7 +116,7 @@ func next(start uintptr, size uintptr) uintptr { func (w *Walker) iterateRangeCanonical(start, end uintptr) { pgdEntryIndex := w.pageTables.root if start >= upperBottom { - pgdEntryIndex = w.pageTables.archPageTables.root + pgdEntryIndex = w.pageTables.upperSharedPageTables.root } for pgdIndex := (uint16((start & pgdMask) >> pgdShift)); start < end && pgdIndex < entriesPerPage; pgdIndex++ { diff --git a/pkg/sentry/socket/control/control_vfs2.go b/pkg/sentry/socket/control/control_vfs2.go index d9621968c..37d02948f 100644 --- a/pkg/sentry/socket/control/control_vfs2.go +++ b/pkg/sentry/socket/control/control_vfs2.go @@ -24,6 +24,8 @@ import ( ) // SCMRightsVFS2 represents a SCM_RIGHTS socket control message. +// +// +stateify savable type SCMRightsVFS2 interface { transport.RightsControlMessage @@ -34,9 +36,11 @@ type SCMRightsVFS2 interface { Files(ctx context.Context, max int) (rf RightsFilesVFS2, truncated bool) } -// RightsFiles represents a SCM_RIGHTS socket control message. A reference is -// maintained for each vfs.FileDescription and is release either when an FD is created or -// when the Release method is called. +// RightsFilesVFS2 represents a SCM_RIGHTS socket control message. A reference +// is maintained for each vfs.FileDescription and is release either when an FD +// is created or when the Release method is called. +// +// +stateify savable type RightsFilesVFS2 []*vfs.FileDescription // NewSCMRightsVFS2 creates a new SCM_RIGHTS socket control message diff --git a/pkg/sentry/socket/hostinet/socket_vfs2.go b/pkg/sentry/socket/hostinet/socket_vfs2.go index 163af329b..9a2cac40b 100644 --- a/pkg/sentry/socket/hostinet/socket_vfs2.go +++ b/pkg/sentry/socket/hostinet/socket_vfs2.go @@ -33,6 +33,7 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) +// +stateify savable type socketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -51,7 +52,7 @@ var _ = socket.SocketVFS2(&socketVFS2{}) func newVFS2Socket(t *kernel.Task, family int, stype linux.SockType, protocol int, fd int, flags uint32) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &socketVFS2{ diff --git a/pkg/sentry/socket/hostinet/stack.go b/pkg/sentry/socket/hostinet/stack.go index faa61160e..7e7857ac3 100644 --- a/pkg/sentry/socket/hostinet/stack.go +++ b/pkg/sentry/socket/hostinet/stack.go @@ -324,7 +324,12 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { } // AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +func (s *Stack) AddInterfaceAddr(int32, inet.InterfaceAddr) error { + return syserror.EACCES +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(int32, inet.InterfaceAddr) error { return syserror.EACCES } @@ -359,7 +364,7 @@ func (s *Stack) TCPSACKEnabled() (bool, error) { } // SetTCPSACKEnabled implements inet.Stack.SetTCPSACKEnabled. -func (s *Stack) SetTCPSACKEnabled(enabled bool) error { +func (s *Stack) SetTCPSACKEnabled(bool) error { return syserror.EACCES } @@ -369,7 +374,7 @@ func (s *Stack) TCPRecovery() (inet.TCPLossRecovery, error) { } // SetTCPRecovery implements inet.Stack.SetTCPRecovery. -func (s *Stack) SetTCPRecovery(recovery inet.TCPLossRecovery) error { +func (s *Stack) SetTCPRecovery(inet.TCPLossRecovery) error { return syserror.EACCES } @@ -430,18 +435,18 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { } if rawLine == "" { - return fmt.Errorf("Failed to get raw line") + return fmt.Errorf("failed to get raw line") } parts := strings.SplitN(rawLine, ":", 2) if len(parts) != 2 { - return fmt.Errorf("Failed to get prefix from: %q", rawLine) + return fmt.Errorf("failed to get prefix from: %q", rawLine) } sliceStat = toSlice(stat) fields := strings.Fields(strings.TrimSpace(parts[1])) if len(fields) != len(sliceStat) { - return fmt.Errorf("Failed to parse fields: %q", rawLine) + return fmt.Errorf("failed to parse fields: %q", rawLine) } if _, ok := stat.(*inet.StatSNMPTCP); ok { snmpTCP = true @@ -457,7 +462,7 @@ func (s *Stack) Statistics(stat interface{}, arg string) error { sliceStat[i], err = strconv.ParseUint(fields[i], 10, 64) } if err != nil { - return fmt.Errorf("Failed to parse field %d from: %q, %v", i, rawLine, err) + return fmt.Errorf("failed to parse field %d from: %q, %v", i, rawLine, err) } } @@ -495,6 +500,6 @@ func (s *Stack) Forwarding(protocol tcpip.NetworkProtocolNumber) bool { } // SetForwarding implements inet.Stack.SetForwarding. -func (s *Stack) SetForwarding(protocol tcpip.NetworkProtocolNumber, enable bool) error { +func (s *Stack) SetForwarding(tcpip.NetworkProtocolNumber, bool) error { return syserror.EACCES } diff --git a/pkg/sentry/socket/netfilter/extensions.go b/pkg/sentry/socket/netfilter/extensions.go index 549787955..e0976fed0 100644 --- a/pkg/sentry/socket/netfilter/extensions.go +++ b/pkg/sentry/socket/netfilter/extensions.go @@ -100,24 +100,43 @@ func unmarshalMatcher(match linux.XTEntryMatch, filter stack.IPHeaderFilter, buf // marshalTarget and unmarshalTarget can be used. type targetMaker interface { // id uniquely identifies the target. - id() stack.TargetID + id() targetID - // marshal converts from a stack.Target to an ABI struct. - marshal(target stack.Target) []byte + // marshal converts from a target to an ABI struct. + marshal(target target) []byte - // unmarshal converts from the ABI matcher struct to a stack.Target. - unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) + // unmarshal converts from the ABI matcher struct to a target. + unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) } -// targetMakers maps the TargetID of supported targets to the targetMaker that +// A targetID uniquely identifies a target. +type targetID struct { + // name is the target name as stored in the xt_entry_target struct. + name string + + // networkProtocol is the protocol to which the target applies. + networkProtocol tcpip.NetworkProtocolNumber + + // revision is the version of the target. + revision uint8 +} + +// target extends a stack.Target, allowing it to be used with the extension +// system. The sentry only uses targets, never stack.Targets directly. +type target interface { + stack.Target + id() targetID +} + +// targetMakers maps the targetID of supported targets to the targetMaker that // marshals and unmarshals it. It is immutable after package initialization. -var targetMakers = map[stack.TargetID]targetMaker{} +var targetMakers = map[targetID]targetMaker{} func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8) (uint8, bool) { - tid := stack.TargetID{ - Name: name, - NetworkProtocol: netProto, - Revision: rev, + tid := targetID{ + name: name, + networkProtocol: netProto, + revision: rev, } if _, ok := targetMakers[tid]; !ok { return 0, false @@ -126,8 +145,8 @@ func targetRevision(name string, netProto tcpip.NetworkProtocolNumber, rev uint8 // Return the highest supported revision unless rev is higher. for _, other := range targetMakers { otherID := other.id() - if name == otherID.Name && netProto == otherID.NetworkProtocol && otherID.Revision > rev { - rev = uint8(otherID.Revision) + if name == otherID.name && netProto == otherID.networkProtocol && otherID.revision > rev { + rev = uint8(otherID.revision) } } return rev, true @@ -142,19 +161,21 @@ func registerTargetMaker(tm targetMaker) { targetMakers[tm.id()] = tm } -func marshalTarget(target stack.Target) []byte { - targetMaker, ok := targetMakers[target.ID()] +func marshalTarget(tgt stack.Target) []byte { + // The sentry only uses targets, never stack.Targets directly. + target := tgt.(target) + targetMaker, ok := targetMakers[target.id()] if !ok { - panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.ID())) + panic(fmt.Sprintf("unknown target of type %T with id %+v.", target, target.id())) } return targetMaker.marshal(target) } -func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (stack.Target, *syserr.Error) { - tid := stack.TargetID{ - Name: target.Name.String(), - NetworkProtocol: filter.NetworkProtocol(), - Revision: target.Revision, +func unmarshalTarget(target linux.XTEntryTarget, filter stack.IPHeaderFilter, buf []byte) (target, *syserr.Error) { + tid := targetID{ + name: target.Name.String(), + networkProtocol: filter.NetworkProtocol(), + revision: target.Revision, } targetMaker, ok := targetMakers[tid] if !ok { diff --git a/pkg/sentry/socket/netfilter/ipv4.go b/pkg/sentry/socket/netfilter/ipv4.go index b560fae0d..70c561cce 100644 --- a/pkg/sentry/socket/netfilter/ipv4.go +++ b/pkg/sentry/socket/netfilter/ipv4.go @@ -46,13 +46,13 @@ func convertNetstackToBinary4(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), false) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIPTGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct. - entries, info := getEntries4(table, tablename) + entries, info := getEntries4(stk.IPTables().GetTable(id, false), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/ipv6.go b/pkg/sentry/socket/netfilter/ipv6.go index 4253f7bf4..5dbb604f0 100644 --- a/pkg/sentry/socket/netfilter/ipv6.go +++ b/pkg/sentry/socket/netfilter/ipv6.go @@ -46,13 +46,13 @@ func convertNetstackToBinary6(stk *stack.Stack, tablename linux.TableName) (linu return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("table name %q too long", tablename) } - table, ok := stk.IPTables().GetTable(tablename.String(), true) + id, ok := nameToID[tablename.String()] if !ok { return linux.KernelIP6TGetEntries{}, linux.IPTGetinfo{}, fmt.Errorf("couldn't find table %q", tablename) } // Setup the info struct, which is the same in IPv4 and IPv6. - entries, info := getEntries6(table, tablename) + entries, info := getEntries6(stk.IPTables().GetTable(id, true), tablename) return entries, info, nil } diff --git a/pkg/sentry/socket/netfilter/netfilter.go b/pkg/sentry/socket/netfilter/netfilter.go index 904a12e38..b283d7229 100644 --- a/pkg/sentry/socket/netfilter/netfilter.go +++ b/pkg/sentry/socket/netfilter/netfilter.go @@ -42,6 +42,45 @@ func nflog(format string, args ...interface{}) { } } +// Table names. +const ( + natTable = "nat" + mangleTable = "mangle" + filterTable = "filter" +) + +// nameToID is immutable. +var nameToID = map[string]stack.TableID{ + natTable: stack.NATID, + mangleTable: stack.MangleID, + filterTable: stack.FilterID, +} + +// DefaultLinuxTables returns the rules of stack.DefaultTables() wrapped for +// compatibility with netfilter extensions. +func DefaultLinuxTables() *stack.IPTables { + tables := stack.DefaultTables() + tables.VisitTargets(func(oldTarget stack.Target) stack.Target { + switch val := oldTarget.(type) { + case *stack.AcceptTarget: + return &acceptTarget{AcceptTarget: *val} + case *stack.DropTarget: + return &dropTarget{DropTarget: *val} + case *stack.ErrorTarget: + return &errorTarget{ErrorTarget: *val} + case *stack.UserChainTarget: + return &userChainTarget{UserChainTarget: *val} + case *stack.ReturnTarget: + return &returnTarget{ReturnTarget: *val} + case *stack.RedirectTarget: + return &redirectTarget{RedirectTarget: *val} + default: + panic(fmt.Sprintf("Unknown rule in default iptables of type %T", val)) + } + }) + return tables +} + // GetInfo returns information about iptables. func GetInfo(t *kernel.Task, stack *stack.Stack, outPtr usermem.Addr, ipv6 bool) (linux.IPTGetinfo, *syserr.Error) { // Read in the struct and table name. @@ -144,9 +183,9 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // TODO(gvisor.dev/issue/170): Support other tables. var table stack.Table switch replace.Name.String() { - case stack.FilterTable: + case filterTable: table = stack.EmptyFilterTable() - case stack.NATTable: + case natTable: table = stack.EmptyNATTable() default: nflog("we don't yet support writing to the %q table (gvisor.dev/issue/170)", replace.Name.String()) @@ -177,7 +216,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { } if offset == replace.Underflow[hook] { if !validUnderflow(table.Rules[ruleIdx], ipv6) { - nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP", ruleIdx) + nflog("underflow for hook %d isn't an unconditional ACCEPT or DROP: %+v", ruleIdx) return syserr.ErrInvalidArgument } table.Underflows[hk] = ruleIdx @@ -253,8 +292,7 @@ func SetEntries(stk *stack.Stack, optVal []byte, ipv6 bool) *syserr.Error { // - There are no chains without an unconditional final rule. // - There are no chains without an unconditional underflow rule. - return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(replace.Name.String(), table, ipv6)) - + return syserr.TranslateNetstackError(stk.IPTables().ReplaceTable(nameToID[replace.Name.String()], table, ipv6)) } // parseMatchers parses 0 or more matchers from optVal. optVal should contain @@ -308,7 +346,7 @@ func validUnderflow(rule stack.Rule, ipv6 bool) bool { return false } switch rule.Target.(type) { - case *stack.AcceptTarget, *stack.DropTarget: + case *acceptTarget, *dropTarget: return true default: return false @@ -319,7 +357,7 @@ func isUnconditionalAccept(rule stack.Rule, ipv6 bool) bool { if !validUnderflow(rule, ipv6) { return false } - _, ok := rule.Target.(*stack.AcceptTarget) + _, ok := rule.Target.(*acceptTarget) return ok } diff --git a/pkg/sentry/socket/netfilter/targets.go b/pkg/sentry/socket/netfilter/targets.go index 0e14447fe..f2653d523 100644 --- a/pkg/sentry/socket/netfilter/targets.go +++ b/pkg/sentry/socket/netfilter/targets.go @@ -26,6 +26,15 @@ import ( "gvisor.dev/gvisor/pkg/usermem" ) +// ErrorTargetName is used to mark targets as error targets. Error targets +// shouldn't be reached - an error has occurred if we fall through to one. +const ErrorTargetName = "ERROR" + +// RedirectTargetName is used to mark targets as redirect targets. Redirect +// targets should be reached for only NAT and Mangle tables. These targets will +// change the destination port and/or IP for packets. +const RedirectTargetName = "REDIRECT" + func init() { // Standard targets include ACCEPT, DROP, RETURN, and JUMP. registerTargetMaker(&standardTargetMaker{ @@ -52,25 +61,96 @@ func init() { }) } +// The stack package provides some basic, useful targets for us. The following +// types wrap them for compatibility with the extension system. + +type acceptTarget struct { + stack.AcceptTarget +} + +func (at *acceptTarget) id() targetID { + return targetID{ + networkProtocol: at.NetworkProtocol, + } +} + +type dropTarget struct { + stack.DropTarget +} + +func (dt *dropTarget) id() targetID { + return targetID{ + networkProtocol: dt.NetworkProtocol, + } +} + +type errorTarget struct { + stack.ErrorTarget +} + +func (et *errorTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: et.NetworkProtocol, + } +} + +type userChainTarget struct { + stack.UserChainTarget +} + +func (uc *userChainTarget) id() targetID { + return targetID{ + name: ErrorTargetName, + networkProtocol: uc.NetworkProtocol, + } +} + +type returnTarget struct { + stack.ReturnTarget +} + +func (rt *returnTarget) id() targetID { + return targetID{ + networkProtocol: rt.NetworkProtocol, + } +} + +type redirectTarget struct { + stack.RedirectTarget + + // addr must be (un)marshalled when reading and writing the target to + // userspace, but does not affect behavior. + addr tcpip.Address +} + +func (rt *redirectTarget) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rt.NetworkProtocol, + } +} + type standardTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (sm *standardTargetMaker) id() stack.TargetID { +func (sm *standardTargetMaker) id() targetID { // Standard targets have the empty string as a name and no revisions. - return stack.TargetID{ - NetworkProtocol: sm.NetworkProtocol, + return targetID{ + networkProtocol: sm.NetworkProtocol, } } -func (*standardTargetMaker) marshal(target stack.Target) []byte { + +func (*standardTargetMaker) marshal(target target) []byte { // Translate verdicts the same way as the iptables tool. var verdict int32 switch tg := target.(type) { - case *stack.AcceptTarget: + case *acceptTarget: verdict = -linux.NF_ACCEPT - 1 - case *stack.DropTarget: + case *dropTarget: verdict = -linux.NF_DROP - 1 - case *stack.ReturnTarget: + case *returnTarget: verdict = linux.NF_RETURN case *JumpTarget: verdict = int32(tg.Offset) @@ -90,7 +170,7 @@ func (*standardTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*standardTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTStandardTarget { nflog("buf has wrong size for standard target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -114,20 +194,20 @@ type errorTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (em *errorTargetMaker) id() stack.TargetID { +func (em *errorTargetMaker) id() targetID { // Error targets have no revision. - return stack.TargetID{ - Name: stack.ErrorTargetName, - NetworkProtocol: em.NetworkProtocol, + return targetID{ + name: ErrorTargetName, + networkProtocol: em.NetworkProtocol, } } -func (*errorTargetMaker) marshal(target stack.Target) []byte { +func (*errorTargetMaker) marshal(target target) []byte { var errorName string switch tg := target.(type) { - case *stack.ErrorTarget: - errorName = stack.ErrorTargetName - case *stack.UserChainTarget: + case *errorTarget: + errorName = ErrorTargetName + case *userChainTarget: errorName = tg.Name default: panic(fmt.Sprintf("errorMakerTarget cannot marshal unknown type %T", target)) @@ -140,37 +220,38 @@ func (*errorTargetMaker) marshal(target stack.Target) []byte { }, } copy(xt.Name[:], errorName) - copy(xt.Target.Name[:], stack.ErrorTargetName) + copy(xt.Target.Name[:], ErrorTargetName) ret := make([]byte, 0, linux.SizeOfXTErrorTarget) return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*errorTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) != linux.SizeOfXTErrorTarget { nflog("buf has insufficient size for error target %d", len(buf)) return nil, syserr.ErrInvalidArgument } - var errorTarget linux.XTErrorTarget + var errTgt linux.XTErrorTarget buf = buf[:linux.SizeOfXTErrorTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &errorTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &errTgt) // Error targets are used in 2 cases: - // * An actual error case. These rules have an error - // named stack.ErrorTargetName. The last entry of the table - // is usually an error case to catch any packets that - // somehow fall through every rule. + // * An actual error case. These rules have an error named + // ErrorTargetName. The last entry of the table is usually an error + // case to catch any packets that somehow fall through every rule. // * To mark the start of a user defined chain. These // rules have an error with the name of the chain. - switch name := errorTarget.Name.String(); name { - case stack.ErrorTargetName: - return &stack.ErrorTarget{NetworkProtocol: filter.NetworkProtocol()}, nil + switch name := errTgt.Name.String(); name { + case ErrorTargetName: + return &errorTarget{stack.ErrorTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }}, nil default: // User defined chain. - return &stack.UserChainTarget{ + return &userChainTarget{stack.UserChainTarget{ Name: name, NetworkProtocol: filter.NetworkProtocol(), - }, nil + }}, nil } } @@ -178,22 +259,22 @@ type redirectTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *redirectTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *redirectTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*redirectTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*redirectTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) // This is a redirect target named redirect xt := linux.XTRedirectTarget{ Target: linux.XTEntryTarget{ TargetSize: linux.SizeOfXTRedirectTarget, }, } - copy(xt.Target.Name[:], stack.RedirectTargetName) + copy(xt.Target.Name[:], RedirectTargetName) ret := make([]byte, 0, linux.SizeOfXTRedirectTarget) xt.NfRange.RangeSize = 1 @@ -203,7 +284,7 @@ func (*redirectTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, xt) } -func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if len(buf) < linux.SizeOfXTRedirectTarget { nflog("redirectTargetMaker: buf has insufficient size for redirect target %d", len(buf)) return nil, syserr.ErrInvalidArgument @@ -214,15 +295,17 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - var redirectTarget linux.XTRedirectTarget + var rt linux.XTRedirectTarget buf = buf[:linux.SizeOfXTRedirectTarget] - binary.Unmarshal(buf, usermem.ByteOrder, &redirectTarget) + binary.Unmarshal(buf, usermem.ByteOrder, &rt) // Copy linux.XTRedirectTarget to stack.RedirectTarget. - target := stack.RedirectTarget{NetworkProtocol: filter.NetworkProtocol()} + target := redirectTarget{RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + }} // RangeSize should be 1. - nfRange := redirectTarget.NfRange + nfRange := rt.NfRange if nfRange.RangeSize != 1 { nflog("redirectTargetMaker: bad rangesize %d", nfRange.RangeSize) return nil, syserr.ErrInvalidArgument @@ -247,7 +330,7 @@ func (*redirectTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) ( return nil, syserr.ErrInvalidArgument } - target.Addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) + target.addr = tcpip.Address(nfRange.RangeIPV4.MinIP[:]) target.Port = ntohs(nfRange.RangeIPV4.MinPort) return &target, nil @@ -264,15 +347,15 @@ type nfNATTargetMaker struct { NetworkProtocol tcpip.NetworkProtocolNumber } -func (rm *nfNATTargetMaker) id() stack.TargetID { - return stack.TargetID{ - Name: stack.RedirectTargetName, - NetworkProtocol: rm.NetworkProtocol, +func (rm *nfNATTargetMaker) id() targetID { + return targetID{ + name: RedirectTargetName, + networkProtocol: rm.NetworkProtocol, } } -func (*nfNATTargetMaker) marshal(target stack.Target) []byte { - rt := target.(*stack.RedirectTarget) +func (*nfNATTargetMaker) marshal(target target) []byte { + rt := target.(*redirectTarget) nt := nfNATTarget{ Target: linux.XTEntryTarget{ TargetSize: nfNATMarhsalledSize, @@ -281,9 +364,9 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { Flags: linux.NF_NAT_RANGE_PROTO_SPECIFIED, }, } - copy(nt.Target.Name[:], stack.RedirectTargetName) - copy(nt.Range.MinAddr[:], rt.Addr) - copy(nt.Range.MaxAddr[:], rt.Addr) + copy(nt.Target.Name[:], RedirectTargetName) + copy(nt.Range.MinAddr[:], rt.addr) + copy(nt.Range.MaxAddr[:], rt.addr) nt.Range.MinProto = htons(rt.Port) nt.Range.MaxProto = nt.Range.MinProto @@ -292,7 +375,7 @@ func (*nfNATTargetMaker) marshal(target stack.Target) []byte { return binary.Marshal(ret, usermem.ByteOrder, nt) } -func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Target, *syserr.Error) { +func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (target, *syserr.Error) { if size := nfNATMarhsalledSize; len(buf) < size { nflog("nfNATTargetMaker: buf has insufficient size (%d) for nfNAT target (%d)", len(buf), size) return nil, syserr.ErrInvalidArgument @@ -324,10 +407,12 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta return nil, syserr.ErrInvalidArgument } - target := stack.RedirectTarget{ - NetworkProtocol: filter.NetworkProtocol(), - Addr: tcpip.Address(natRange.MinAddr[:]), - Port: ntohs(natRange.MinProto), + target := redirectTarget{ + RedirectTarget: stack.RedirectTarget{ + NetworkProtocol: filter.NetworkProtocol(), + Port: ntohs(natRange.MinProto), + }, + addr: tcpip.Address(natRange.MinAddr[:]), } return &target, nil @@ -335,18 +420,24 @@ func (*nfNATTargetMaker) unmarshal(buf []byte, filter stack.IPHeaderFilter) (sta // translateToStandardTarget translates from the value in a // linux.XTStandardTarget to an stack.Verdict. -func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (stack.Target, *syserr.Error) { +func translateToStandardTarget(val int32, netProto tcpip.NetworkProtocolNumber) (target, *syserr.Error) { // TODO(gvisor.dev/issue/170): Support other verdicts. switch val { case -linux.NF_ACCEPT - 1: - return &stack.AcceptTarget{NetworkProtocol: netProto}, nil + return &acceptTarget{stack.AcceptTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_DROP - 1: - return &stack.DropTarget{NetworkProtocol: netProto}, nil + return &dropTarget{stack.DropTarget{ + NetworkProtocol: netProto, + }}, nil case -linux.NF_QUEUE - 1: nflog("unsupported iptables verdict QUEUE") return nil, syserr.ErrInvalidArgument case linux.NF_RETURN: - return &stack.ReturnTarget{NetworkProtocol: netProto}, nil + return &returnTarget{stack.ReturnTarget{ + NetworkProtocol: netProto, + }}, nil default: nflog("unknown iptables verdict %d", val) return nil, syserr.ErrInvalidArgument @@ -382,9 +473,9 @@ type JumpTarget struct { } // ID implements Target.ID. -func (jt *JumpTarget) ID() stack.TargetID { - return stack.TargetID{ - NetworkProtocol: jt.NetworkProtocol, +func (jt *JumpTarget) id() targetID { + return targetID{ + networkProtocol: jt.NetworkProtocol, } } diff --git a/pkg/sentry/socket/netfilter/tcp_matcher.go b/pkg/sentry/socket/netfilter/tcp_matcher.go index 844acfede..352c51390 100644 --- a/pkg/sentry/socket/netfilter/tcp_matcher.go +++ b/pkg/sentry/socket/netfilter/tcp_matcher.go @@ -71,7 +71,7 @@ func (tcpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.TCPProtocolNumber { - return nil, fmt.Errorf("TCP matching is only valid for protocol %d.", header.TCPProtocolNumber) + return nil, fmt.Errorf("TCP matching is only valid for protocol %d", header.TCPProtocolNumber) } return &TCPMatcher{ diff --git a/pkg/sentry/socket/netfilter/udp_matcher.go b/pkg/sentry/socket/netfilter/udp_matcher.go index 63201201c..c88d8268d 100644 --- a/pkg/sentry/socket/netfilter/udp_matcher.go +++ b/pkg/sentry/socket/netfilter/udp_matcher.go @@ -68,7 +68,7 @@ func (udpMarshaler) unmarshal(buf []byte, filter stack.IPHeaderFilter) (stack.Ma } if filter.Protocol != header.UDPProtocolNumber { - return nil, fmt.Errorf("UDP matching is only valid for protocol %d.", header.UDPProtocolNumber) + return nil, fmt.Errorf("UDP matching is only valid for protocol %d", header.UDPProtocolNumber) } return &UDPMatcher{ diff --git a/pkg/sentry/socket/netlink/provider_vfs2.go b/pkg/sentry/socket/netlink/provider_vfs2.go index e8930f031..f061c5d62 100644 --- a/pkg/sentry/socket/netlink/provider_vfs2.go +++ b/pkg/sentry/socket/netlink/provider_vfs2.go @@ -51,7 +51,7 @@ func (*socketProviderVFS2) Socket(t *kernel.Task, stype linux.SockType, protocol vfsfd := &s.vfsfd mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) if err := vfsfd.Init(s, linux.O_RDWR, mnt, d, &vfs.FileDescriptionOptions{ DenyPRead: true, diff --git a/pkg/sentry/socket/netlink/route/protocol.go b/pkg/sentry/socket/netlink/route/protocol.go index c84d8bd7c..f4d034c13 100644 --- a/pkg/sentry/socket/netlink/route/protocol.go +++ b/pkg/sentry/socket/netlink/route/protocol.go @@ -36,9 +36,9 @@ type commandKind int const ( kindNew commandKind = 0x0 - kindDel = 0x1 - kindGet = 0x2 - kindSet = 0x3 + kindDel commandKind = 0x1 + kindGet commandKind = 0x2 + kindSet commandKind = 0x3 ) func typeKind(typ uint16) commandKind { @@ -423,6 +423,11 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } attrs = rest + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We add the local interface address here + // and ignore the IFA_ADDRESS. switch ahdr.Type { case linux.IFA_LOCAL: err := stack.AddInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ @@ -439,11 +444,60 @@ func (p *Protocol) newAddr(ctx context.Context, msg *netlink.Message, ms *netlin } else if err != nil { return syserr.ErrInvalidArgument } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported } } return nil } +// delAddr handles RTM_DELADDR requests. +func (p *Protocol) delAddr(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { + stack := inet.StackFromContext(ctx) + if stack == nil { + // No network stack. + return syserr.ErrProtocolNotSupported + } + + var ifa linux.InterfaceAddrMessage + attrs, ok := msg.GetData(&ifa) + if !ok { + return syserr.ErrInvalidArgument + } + + for !attrs.Empty() { + ahdr, value, rest, ok := attrs.ParseFirst() + if !ok { + return syserr.ErrInvalidArgument + } + attrs = rest + + // NOTE: A netlink message will contain multiple header attributes. + // Both the IFA_ADDRESS and IFA_LOCAL attributes are typically sent + // with IFA_ADDRESS being a prefix address and IFA_LOCAL being the + // local interface address. We use the local interface address to + // remove the address and ignore the IFA_ADDRESS. + switch ahdr.Type { + case linux.IFA_LOCAL: + err := stack.RemoveInterfaceAddr(int32(ifa.Index), inet.InterfaceAddr{ + Family: ifa.Family, + PrefixLen: ifa.PrefixLen, + Flags: ifa.Flags, + Addr: value, + }) + if err != nil { + return syserr.ErrBadLocalAddress + } + case linux.IFA_ADDRESS: + default: + return syserr.ErrNotSupported + } + } + + return nil +} + // ProcessMessage implements netlink.Protocol.ProcessMessage. func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms *netlink.MessageSet) *syserr.Error { hdr := msg.Header() @@ -485,6 +539,8 @@ func (p *Protocol) ProcessMessage(ctx context.Context, msg *netlink.Message, ms return p.dumpRoutes(ctx, msg, ms) case linux.RTM_NEWADDR: return p.newAddr(ctx, msg, ms) + case linux.RTM_DELADDR: + return p.delAddr(ctx, msg, ms) default: return syserr.ErrNotSupported } diff --git a/pkg/sentry/socket/netlink/socket_vfs2.go b/pkg/sentry/socket/netlink/socket_vfs2.go index c83b23242..461d524e5 100644 --- a/pkg/sentry/socket/netlink/socket_vfs2.go +++ b/pkg/sentry/socket/netlink/socket_vfs2.go @@ -37,6 +37,8 @@ import ( // to/from the kernel. // // SocketVFS2 implements socket.SocketVFS2 and transport.Credentialer. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl diff --git a/pkg/sentry/socket/netstack/netstack.go b/pkg/sentry/socket/netstack/netstack.go index 211f07947..86c634715 100644 --- a/pkg/sentry/socket/netstack/netstack.go +++ b/pkg/sentry/socket/netstack/netstack.go @@ -1244,6 +1244,18 @@ func getSockOptSocket(t *kernel.Task, s socket.SocketOps, ep commonEndpoint, fam vP := primitive.Int32(boolToInt32(v)) return &vP, nil + case linux.SO_ACCEPTCONN: + if outLen < sizeOfInt32 { + return nil, syserr.ErrInvalidArgument + } + + v, err := ep.GetSockOptBool(tcpip.AcceptConnOption) + if err != nil { + return nil, syserr.TranslateNetstackError(err) + } + vP := primitive.Int32(boolToInt32(v)) + return &vP, nil + default: socket.GetSockOptEmitUnimplementedEvent(t, name) } diff --git a/pkg/sentry/socket/netstack/netstack_vfs2.go b/pkg/sentry/socket/netstack/netstack_vfs2.go index 4c6791fff..b0d9e4d9e 100644 --- a/pkg/sentry/socket/netstack/netstack_vfs2.go +++ b/pkg/sentry/socket/netstack/netstack_vfs2.go @@ -35,6 +35,8 @@ import ( // SocketVFS2 encapsulates all the state needed to represent a network stack // endpoint in the kernel context. +// +// +stateify savable type SocketVFS2 struct { vfsfd vfs.FileDescription vfs.FileDescriptionDefaultImpl @@ -55,7 +57,7 @@ func NewVFS2(t *kernel.Task, family int, skType linux.SockType, protocol int, qu } mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) s := &SocketVFS2{ diff --git a/pkg/sentry/socket/netstack/stack.go b/pkg/sentry/socket/netstack/stack.go index 1028d2a6e..fa9ac9059 100644 --- a/pkg/sentry/socket/netstack/stack.go +++ b/pkg/sentry/socket/netstack/stack.go @@ -100,56 +100,101 @@ func (s *Stack) InterfaceAddrs() map[int32][]inet.InterfaceAddr { return nicAddrs } -// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. -func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { +// convertAddr converts an InterfaceAddr to a ProtocolAddress. +func convertAddr(addr inet.InterfaceAddr) (tcpip.ProtocolAddress, error) { var ( - protocol tcpip.NetworkProtocolNumber - address tcpip.Address + protocol tcpip.NetworkProtocolNumber + address tcpip.Address + protocolAddress tcpip.ProtocolAddress ) switch addr.Family { case linux.AF_INET: - if len(addr.Addr) < header.IPv4AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv4AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv4AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv4.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv4AddressSize]) - + address = tcpip.Address(addr.Addr) case linux.AF_INET6: - if len(addr.Addr) < header.IPv6AddressSize { - return syserror.EINVAL + if len(addr.Addr) != header.IPv6AddressSize { + return protocolAddress, syserror.EINVAL } if addr.PrefixLen > header.IPv6AddressSize*8 { - return syserror.EINVAL + return protocolAddress, syserror.EINVAL } protocol = ipv6.ProtocolNumber - address = tcpip.Address(addr.Addr[:header.IPv6AddressSize]) - + address = tcpip.Address(addr.Addr) default: - return syserror.ENOTSUP + return protocolAddress, syserror.ENOTSUP } - protocolAddress := tcpip.ProtocolAddress{ + protocolAddress = tcpip.ProtocolAddress{ Protocol: protocol, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: address, PrefixLen: int(addr.PrefixLen), }, } + return protocolAddress, nil +} + +// AddInterfaceAddr implements inet.Stack.AddInterfaceAddr. +func (s *Stack) AddInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } // Attach address to interface. - if err := s.Stack.AddProtocolAddressWithOptions(tcpip.NICID(idx), protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + nicID := tcpip.NICID(idx) + if err := s.Stack.AddProtocolAddressWithOptions(nicID, protocolAddress, stack.CanBePrimaryEndpoint); err != nil { + return syserr.TranslateNetstackError(err).ToError() + } + + // Add route for local network if it doesn't exist already. + localRoute := tcpip.Route{ + Destination: protocolAddress.AddressWithPrefix.Subnet(), + Gateway: "", // No gateway for local network. + NIC: nicID, + } + + for _, rt := range s.Stack.GetRouteTable() { + if rt.Equal(localRoute) { + return nil + } + } + + // Local route does not exist yet. Add it. + s.Stack.AddRoute(localRoute) + + return nil +} + +// RemoveInterfaceAddr implements inet.Stack.RemoveInterfaceAddr. +func (s *Stack) RemoveInterfaceAddr(idx int32, addr inet.InterfaceAddr) error { + protocolAddress, err := convertAddr(addr) + if err != nil { + return err + } + + // Remove addresses matching the address and prefix. + nicID := tcpip.NICID(idx) + if err := s.Stack.RemoveAddress(nicID, protocolAddress.AddressWithPrefix.Address); err != nil { return syserr.TranslateNetstackError(err).ToError() } - // Add route for local network. - s.Stack.AddRoute(tcpip.Route{ + // Remove the corresponding local network route if it exists. + localRoute := tcpip.Route{ Destination: protocolAddress.AddressWithPrefix.Subnet(), Gateway: "", // No gateway for local network. - NIC: tcpip.NICID(idx), + NIC: nicID, + } + s.Stack.RemoveRoutes(func(rt tcpip.Route) bool { + return rt.Equal(localRoute) }) + return nil } diff --git a/pkg/sentry/socket/unix/BUILD b/pkg/sentry/socket/unix/BUILD index cc7408698..cce0acc33 100644 --- a/pkg/sentry/socket/unix/BUILD +++ b/pkg/sentry/socket/unix/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "socket_refs.go", package = "unix", prefix = "socketOperations", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketOperations", }, @@ -19,7 +19,7 @@ go_template_instance( out = "socket_vfs2_refs.go", package = "unix", prefix = "socketVFS2", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "SocketVFS2", }, @@ -43,6 +43,7 @@ go_library( "//pkg/log", "//pkg/marshal", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/device", diff --git a/pkg/sentry/socket/unix/transport/BUILD b/pkg/sentry/socket/unix/transport/BUILD index 26c3a51b9..3ebbd28b0 100644 --- a/pkg/sentry/socket/unix/transport/BUILD +++ b/pkg/sentry/socket/unix/transport/BUILD @@ -20,7 +20,7 @@ go_template_instance( out = "queue_refs.go", package = "transport", prefix = "queue", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "queue", }, @@ -44,6 +44,7 @@ go_library( "//pkg/ilist", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserr", "//pkg/tcpip", diff --git a/pkg/sentry/socket/unix/transport/unix.go b/pkg/sentry/socket/unix/transport/unix.go index d6fc03520..b648273a4 100644 --- a/pkg/sentry/socket/unix/transport/unix.go +++ b/pkg/sentry/socket/unix/transport/unix.go @@ -32,6 +32,8 @@ import ( const initialLimit = 16 * 1024 // A RightsControlMessage is a control message containing FDs. +// +// +stateify savable type RightsControlMessage interface { // Clone returns a copy of the RightsControlMessage. Clone() RightsControlMessage @@ -336,7 +338,7 @@ type Receiver interface { RecvMaxQueueSize() int64 // Release releases any resources owned by the Receiver. It should be - // called before droping all references to a Receiver. + // called before dropping all references to a Receiver. Release(ctx context.Context) } @@ -487,7 +489,7 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds c := q.control.Clone() // Don't consume data since we are peeking. - copied, data, _ = vecCopy(data, q.buffer) + copied, _, _ = vecCopy(data, q.buffer) return copied, copied, c, false, q.addr, notify, nil } @@ -572,6 +574,12 @@ func (q *streamQueueReceiver) Recv(ctx context.Context, data [][]byte, wantCreds return copied, copied, c, cmTruncated, q.addr, notify, nil } +// Release implements Receiver.Release. +func (q *streamQueueReceiver) Release(ctx context.Context) { + q.queueReceiver.Release(ctx) + q.control.Release(ctx) +} + // A ConnectedEndpoint is an Endpoint that can be used to send Messages. type ConnectedEndpoint interface { // Passcred implements Endpoint.Passcred. @@ -619,7 +627,7 @@ type ConnectedEndpoint interface { SendMaxQueueSize() int64 // Release releases any resources owned by the ConnectedEndpoint. It should - // be called before droping all references to a ConnectedEndpoint. + // be called before dropping all references to a ConnectedEndpoint. Release(ctx context.Context) // CloseUnread sets the fact that this end is closed with unread data to @@ -879,7 +887,7 @@ func (e *baseEndpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { func (e *baseEndpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.PasscredOption: diff --git a/pkg/sentry/socket/unix/unix.go b/pkg/sentry/socket/unix/unix.go index a4a76d0a3..adad485a9 100644 --- a/pkg/sentry/socket/unix/unix.go +++ b/pkg/sentry/socket/unix/unix.go @@ -81,7 +81,6 @@ func NewWithDirent(ctx context.Context, d *fs.Dirent, ep transport.Endpoint, sty }, } s.EnableLeakCheck() - return fs.NewFile(ctx, d, flags, &s) } diff --git a/pkg/sentry/socket/unix/unix_vfs2.go b/pkg/sentry/socket/unix/unix_vfs2.go index 678355fb9..7a78444dc 100644 --- a/pkg/sentry/socket/unix/unix_vfs2.go +++ b/pkg/sentry/socket/unix/unix_vfs2.go @@ -55,7 +55,7 @@ var _ = socket.SocketVFS2(&SocketVFS2{}) // returns a corresponding file description. func NewSockfsFile(t *kernel.Task, ep transport.Endpoint, stype linux.SockType) (*vfs.FileDescription, *syserr.Error) { mnt := t.Kernel().SocketMount() - d := sockfs.NewDentry(t.Credentials(), mnt) + d := sockfs.NewDentry(t, mnt) defer d.DecRef(t) fd, err := NewFileDescription(ep, stype, linux.O_RDWR, mnt, d, &vfs.FileLocks{}) @@ -80,6 +80,7 @@ func NewFileDescription(ep transport.Endpoint, stype linux.SockType, flags uint3 stype: stype, }, } + sock.EnableLeakCheck() sock.LockFD.Init(locks) vfsfd := &sock.vfsfd if err := vfsfd.Init(sock, flags, mnt, d, &vfs.FileDescriptionOptions{ diff --git a/pkg/sentry/state/BUILD b/pkg/sentry/state/BUILD index 0ea4aab8b..563d60578 100644 --- a/pkg/sentry/state/BUILD +++ b/pkg/sentry/state/BUILD @@ -12,10 +12,12 @@ go_library( visibility = ["//pkg/sentry:internal"], deps = [ "//pkg/abi/linux", + "//pkg/context", "//pkg/log", "//pkg/sentry/inet", "//pkg/sentry/kernel", "//pkg/sentry/time", + "//pkg/sentry/vfs", "//pkg/sentry/watchdog", "//pkg/state/statefile", "//pkg/syserror", diff --git a/pkg/sentry/state/state.go b/pkg/sentry/state/state.go index 245d2c5cf..167754537 100644 --- a/pkg/sentry/state/state.go +++ b/pkg/sentry/state/state.go @@ -19,10 +19,12 @@ import ( "fmt" "io" + "gvisor.dev/gvisor/pkg/context" "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/sentry/inet" "gvisor.dev/gvisor/pkg/sentry/kernel" "gvisor.dev/gvisor/pkg/sentry/time" + "gvisor.dev/gvisor/pkg/sentry/vfs" "gvisor.dev/gvisor/pkg/sentry/watchdog" "gvisor.dev/gvisor/pkg/state/statefile" "gvisor.dev/gvisor/pkg/syserror" @@ -57,7 +59,7 @@ type SaveOpts struct { } // Save saves the system state. -func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { +func (opts SaveOpts) Save(ctx context.Context, k *kernel.Kernel, w *watchdog.Watchdog) error { log.Infof("Sandbox save started, pausing all tasks.") k.Pause() k.ReceiveTaskStates() @@ -81,7 +83,7 @@ func (opts SaveOpts) Save(k *kernel.Kernel, w *watchdog.Watchdog) error { err = ErrStateFile{err} } else { // Save the kernel. - err = k.SaveTo(wc) + err = k.SaveTo(ctx, wc) // ENOSPC is a state file error. This error can only come from // writing the state file, and not from fs.FileOperations.Fsync @@ -108,7 +110,7 @@ type LoadOpts struct { } // Load loads the given kernel, setting the provided platform and stack. -func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) error { +func (opts LoadOpts) Load(ctx context.Context, k *kernel.Kernel, n inet.Stack, clocks time.Clocks, vfsOpts *vfs.CompleteRestoreOptions) error { // Open the file. r, m, err := statefile.NewReader(opts.Source, opts.Key) if err != nil { @@ -118,5 +120,5 @@ func (opts LoadOpts) Load(k *kernel.Kernel, n inet.Stack, clocks time.Clocks) er previousMetadata = m // Restore the Kernel object graph. - return k.LoadFrom(r, n, clocks) + return k.LoadFrom(ctx, r, n, clocks, vfsOpts) } diff --git a/pkg/sentry/syscalls/linux/linux64.go b/pkg/sentry/syscalls/linux/linux64.go index 9c9def7cd..bb1f715e2 100644 --- a/pkg/sentry/syscalls/linux/linux64.go +++ b/pkg/sentry/syscalls/linux/linux64.go @@ -118,7 +118,7 @@ var AMD64 = &kernel.SyscallTable{ 63: syscalls.Supported("uname", Uname), 64: syscalls.Supported("semget", Semget), 65: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), - 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 66: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 67: syscalls.Supported("shmdt", Shmdt), 68: syscalls.ErrorWithEvent("msgget", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 69: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) @@ -619,7 +619,7 @@ var ARM64 = &kernel.SyscallTable{ 188: syscalls.ErrorWithEvent("msgrcv", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 189: syscalls.ErrorWithEvent("msgsnd", syserror.ENOSYS, "", []string{"gvisor.dev/issue/135"}), // TODO(b/29354921) 190: syscalls.Supported("semget", Semget), - 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, IPC_STAT, SEM_STAT, SEM_STAT_ANY, GETNCNT, GETZCNT not supported.", nil), + 191: syscalls.PartiallySupported("semctl", Semctl, "Options IPC_INFO, SEM_INFO, SEM_STAT, SEM_STAT_ANY not supported.", nil), 192: syscalls.ErrorWithEvent("semtimedop", syserror.ENOSYS, "", []string{"gvisor.dev/issue/137"}), 193: syscalls.PartiallySupported("semop", Semop, "Option SEM_UNDO not supported.", nil), 194: syscalls.PartiallySupported("shmget", Shmget, "Option SHM_HUGETLB is not supported.", nil), diff --git a/pkg/sentry/syscalls/linux/sys_pipe.go b/pkg/sentry/syscalls/linux/sys_pipe.go index 849a47476..f7135ea46 100644 --- a/pkg/sentry/syscalls/linux/sys_pipe.go +++ b/pkg/sentry/syscalls/linux/sys_pipe.go @@ -32,7 +32,7 @@ func pipe2(t *kernel.Task, addr usermem.Addr, flags uint) (uintptr, error) { if flags&^(linux.O_NONBLOCK|linux.O_CLOEXEC) != 0 { return 0, syserror.EINVAL } - r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize, usermem.PageSize) + r, w := pipe.NewConnectedPipe(t, pipe.DefaultPipeSize) r.SetFlags(linuxToFlags(flags).Settable()) defer r.DecRef(t) diff --git a/pkg/sentry/syscalls/linux/sys_sem.go b/pkg/sentry/syscalls/linux/sys_sem.go index 47dadb800..e383a0a87 100644 --- a/pkg/sentry/syscalls/linux/sys_sem.go +++ b/pkg/sentry/syscalls/linux/sys_sem.go @@ -129,13 +129,27 @@ func Semctl(t *kernel.Task, args arch.SyscallArguments) (uintptr, *kernel.Syscal v, err := getPID(t, id, num) return uintptr(v), nil, err + case linux.IPC_STAT: + arg := args[3].Pointer() + ds, err := ipcStat(t, id) + if err == nil { + _, err = ds.CopyOut(t, arg) + } + + return 0, nil, err + + case linux.GETZCNT: + v, err := getZCnt(t, id, num) + return uintptr(v), nil, err + + case linux.GETNCNT: + v, err := getNCnt(t, id, num) + return uintptr(v), nil, err + case linux.IPC_INFO, linux.SEM_INFO, - linux.IPC_STAT, linux.SEM_STAT, - linux.SEM_STAT_ANY, - linux.GETNCNT, - linux.GETZCNT: + linux.SEM_STAT_ANY: t.Kernel().EmitUnimplementedEvent(t) fallthrough @@ -171,6 +185,16 @@ func ipcSet(t *kernel.Task, id int32, uid auth.UID, gid auth.GID, perms fs.FileP return set.Change(t, creds, owner, perms) } +func ipcStat(t *kernel.Task, id int32) (*linux.SemidDS, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return nil, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.GetStat(creds) +} + func setVal(t *kernel.Task, id int32, num int32, val int16) error { r := t.IPCNamespace().SemaphoreRegistry() set := r.FindByID(id) @@ -240,3 +264,23 @@ func getPID(t *kernel.Task, id int32, num int32) (int32, error) { } return int32(tg.ID()), nil } + +func getZCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountZeroWaiters(num, creds) +} + +func getNCnt(t *kernel.Task, id int32, num int32) (uint16, error) { + r := t.IPCNamespace().SemaphoreRegistry() + set := r.FindByID(id) + if set == nil { + return 0, syserror.EINVAL + } + creds := auth.CredentialsFromContext(t) + return set.CountNegativeWaiters(num, creds) +} diff --git a/pkg/sentry/syscalls/linux/sys_splice.go b/pkg/sentry/syscalls/linux/sys_splice.go index 46616c961..1c4cdb0dd 100644 --- a/pkg/sentry/syscalls/linux/sys_splice.go +++ b/pkg/sentry/syscalls/linux/sys_splice.go @@ -41,6 +41,7 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inCh chan struct{} outCh chan struct{} ) + for opts.Length > 0 { n, err = fs.Splice(t, outFile, inFile, opts) opts.Length -= n @@ -61,23 +62,28 @@ func doSplice(t *kernel.Task, outFile, inFile *fs.File, opts fs.SpliceOpts, nonB inW, _ := waiter.NewChannelEntry(inCh) inFile.EventRegister(&inW, EventMaskRead) defer inFile.EventUnregister(&inW) - continue // Need to refresh readiness. + // Need to refresh readiness. + continue } if err = t.Block(inCh); err != nil { break } } - if outFile.Readiness(EventMaskWrite) == 0 { - if outCh == nil { - outCh = make(chan struct{}, 1) - outW, _ := waiter.NewChannelEntry(outCh) - outFile.EventRegister(&outW, EventMaskWrite) - defer outFile.EventUnregister(&outW) - continue // Need to refresh readiness. - } - if err = t.Block(outCh); err != nil { - break - } + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. + if outCh == nil { + outCh = make(chan struct{}, 1) + outW, _ := waiter.NewChannelEntry(outCh) + outFile.EventRegister(&outW, EventMaskWrite) + defer outFile.EventUnregister(&outW) + // We might be ready to write now. Try again before + // blocking. + continue + } + if err = t.Block(outCh); err != nil { + break } } diff --git a/pkg/sentry/syscalls/linux/vfs2/splice.go b/pkg/sentry/syscalls/linux/vfs2/splice.go index 035e2a6b0..9ce4f280a 100644 --- a/pkg/sentry/syscalls/linux/vfs2/splice.go +++ b/pkg/sentry/syscalls/linux/vfs2/splice.go @@ -480,18 +480,17 @@ func (dw *dualWaiter) waitForBoth(t *kernel.Task) error { // waitForOut waits for dw.outfile to be read. func (dw *dualWaiter) waitForOut(t *kernel.Task) error { - if dw.outFile.Readiness(eventMaskWrite)&eventMaskWrite == 0 { - if dw.outCh == nil { - dw.outW, dw.outCh = waiter.NewChannelEntry(nil) - dw.outFile.EventRegister(&dw.outW, eventMaskWrite) - // We might be ready now. Try again before blocking. - return nil - } - if err := t.Block(dw.outCh); err != nil { - return err - } - } - return nil + // Don't bother checking readiness of the outFile, because it's not a + // guarantee that it won't return EWOULDBLOCK. Both pipes and eventfds + // can be "ready" but will reject writes of certain sizes with + // EWOULDBLOCK. See b/172075629, b/170743336. + if dw.outCh == nil { + dw.outW, dw.outCh = waiter.NewChannelEntry(nil) + dw.outFile.EventRegister(&dw.outW, eventMaskWrite) + // We might be ready to write now. Try again before blocking. + return nil + } + return t.Block(dw.outCh) } // destroy cleans up resources help by dw. No more calls to wait* can occur diff --git a/pkg/sentry/vfs/BUILD b/pkg/sentry/vfs/BUILD index c855608db..440c9307c 100644 --- a/pkg/sentry/vfs/BUILD +++ b/pkg/sentry/vfs/BUILD @@ -32,7 +32,7 @@ go_template_instance( out = "file_description_refs.go", package = "vfs", prefix = "FileDescription", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "FileDescription", }, @@ -43,7 +43,7 @@ go_template_instance( out = "mount_namespace_refs.go", package = "vfs", prefix = "MountNamespace", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "MountNamespace", }, @@ -54,7 +54,7 @@ go_template_instance( out = "filesystem_refs.go", package = "vfs", prefix = "Filesystem", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "Filesystem", }, @@ -87,6 +87,7 @@ go_library( "pathname.go", "permissions.go", "resolving_path.go", + "save_restore.go", "vfs.go", ], visibility = ["//pkg/sentry:internal"], @@ -99,6 +100,7 @@ go_library( "//pkg/gohacks", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/safemem", "//pkg/sentry/arch", "//pkg/sentry/fs", diff --git a/pkg/sentry/vfs/epoll.go b/pkg/sentry/vfs/epoll.go index 8f36c3e3b..a98aac52b 100644 --- a/pkg/sentry/vfs/epoll.go +++ b/pkg/sentry/vfs/epoll.go @@ -74,7 +74,7 @@ type epollInterestKey struct { // +stateify savable type epollInterest struct { // epoll is the owning EpollInstance. epoll is immutable. - epoll *EpollInstance + epoll *EpollInstance `state:"wait"` // key is the file to which this epollInterest applies. key is immutable. key epollInterestKey diff --git a/pkg/sentry/vfs/file_description.go b/pkg/sentry/vfs/file_description.go index 183957ad8..546e445aa 100644 --- a/pkg/sentry/vfs/file_description.go +++ b/pkg/sentry/vfs/file_description.go @@ -183,7 +183,6 @@ func (fd *FileDescription) DecRef(ctx context.Context) { } fd.vd.DecRef(ctx) fd.flagsMu.Lock() - // TODO(gvisor.dev/issue/1663): We may need to unregister during save, as we do in VFS1. if fd.statusFlags&linux.O_ASYNC != 0 && fd.asyncHandler != nil { fd.asyncHandler.Unregister(fd) } diff --git a/pkg/sentry/vfs/genericfstree/genericfstree.go b/pkg/sentry/vfs/genericfstree/genericfstree.go index 2d27d9d35..ba6e6ed49 100644 --- a/pkg/sentry/vfs/genericfstree/genericfstree.go +++ b/pkg/sentry/vfs/genericfstree/genericfstree.go @@ -71,7 +71,7 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath if mnt == vfsroot.Mount() && &d.vfsd == vfsroot.Dentry() { return vfs.PrependPathAtVFSRootError{} } - if &d.vfsd == mnt.Root() { + if mnt != nil && &d.vfsd == mnt.Root() { return nil } if d.parent == nil { @@ -81,3 +81,12 @@ func PrependPath(vfsroot vfs.VirtualDentry, mnt *vfs.Mount, d *Dentry, b *fspath d = d.parent } } + +// DebugPathname returns a pathname to d relative to its filesystem root. +// DebugPathname does not correspond to any Linux function; it's used to +// generate dentry pathnames for debugging. +func DebugPathname(d *Dentry) string { + var b fspath.Builder + _ = PrependPath(vfs.VirtualDentry{}, nil, d, &b) + return b.String() +} diff --git a/pkg/sentry/vfs/inotify.go b/pkg/sentry/vfs/inotify.go index 3f0b8f45b..107171b61 100644 --- a/pkg/sentry/vfs/inotify.go +++ b/pkg/sentry/vfs/inotify.go @@ -65,7 +65,7 @@ type Inotify struct { // queue is used to notify interested parties when the inotify instance // becomes readable or writable. - queue waiter.Queue `state:"nosave"` + queue waiter.Queue // evMu *only* protects the events list. We need a separate lock while // queuing events: using mu may violate lock ordering, since at that point diff --git a/pkg/sentry/vfs/lock.go b/pkg/sentry/vfs/lock.go index 55783d4eb..1ff202f2a 100644 --- a/pkg/sentry/vfs/lock.go +++ b/pkg/sentry/vfs/lock.go @@ -12,11 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package lock provides POSIX and BSD style file locking for VFS2 file -// implementations. -// -// The actual implementations can be found in the lock package under -// sentry/fs/lock. package vfs import ( diff --git a/pkg/sentry/vfs/mount.go b/pkg/sentry/vfs/mount.go index 78f115bfa..3ea981ad4 100644 --- a/pkg/sentry/vfs/mount.go +++ b/pkg/sentry/vfs/mount.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/abi/linux" "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" "gvisor.dev/gvisor/pkg/sentry/kernel/auth" "gvisor.dev/gvisor/pkg/syserror" ) @@ -106,6 +107,7 @@ func newMount(vfs *VirtualFilesystem, fs *Filesystem, root *Dentry, mntns *Mount if opts.ReadOnly { mnt.setReadOnlyLocked(true) } + refsvfs2.Register(mnt) return mnt } @@ -470,11 +472,12 @@ func (vfs *VirtualFilesystem) disconnectLocked(mnt *Mount) VirtualDentry { // tryIncMountedRef does not require that a reference is held on mnt. func (mnt *Mount) tryIncMountedRef() bool { for { - refs := atomic.LoadInt64(&mnt.refs) - if refs <= 0 { // refs < 0 => MSB set => eagerly unmounted + r := atomic.LoadInt64(&mnt.refs) + if r <= 0 { // r < 0 => MSB set => eagerly unmounted return false } - if atomic.CompareAndSwapInt64(&mnt.refs, refs, refs+1) { + if atomic.CompareAndSwapInt64(&mnt.refs, r, r+1) { + refsvfs2.LogTryIncRef(mnt, r+1) return true } } @@ -484,29 +487,53 @@ func (mnt *Mount) tryIncMountedRef() bool { func (mnt *Mount) IncRef() { // In general, negative values for mnt.refs are valid because the MSB is // the eager-unmount bit. - atomic.AddInt64(&mnt.refs, 1) + r := atomic.AddInt64(&mnt.refs, 1) + refsvfs2.LogIncRef(mnt, r) } // DecRef decrements mnt's reference count. func (mnt *Mount) DecRef(ctx context.Context) { - refs := atomic.AddInt64(&mnt.refs, -1) - if refs&^math.MinInt64 == 0 { // mask out MSB - var vd VirtualDentry - if mnt.parent() != nil { - mnt.vfs.mountMu.Lock() - mnt.vfs.mounts.seq.BeginWrite() - vd = mnt.vfs.disconnectLocked(mnt) - mnt.vfs.mounts.seq.EndWrite() - mnt.vfs.mountMu.Unlock() - } - if mnt.root != nil { - mnt.root.DecRef(ctx) - } - mnt.fs.DecRef(ctx) - if vd.Ok() { - vd.DecRef(ctx) - } + r := atomic.AddInt64(&mnt.refs, -1) + if r&^math.MinInt64 == 0 { // mask out MSB + refsvfs2.Unregister(mnt) + mnt.destroy(ctx) + } +} + +func (mnt *Mount) destroy(ctx context.Context) { + var vd VirtualDentry + if mnt.parent() != nil { + mnt.vfs.mountMu.Lock() + mnt.vfs.mounts.seq.BeginWrite() + vd = mnt.vfs.disconnectLocked(mnt) + mnt.vfs.mounts.seq.EndWrite() + mnt.vfs.mountMu.Unlock() + } + if mnt.root != nil { + mnt.root.DecRef(ctx) } + mnt.fs.DecRef(ctx) + if vd.Ok() { + vd.DecRef(ctx) + } +} + +// RefType implements refsvfs2.CheckedObject.Type. +func (mnt *Mount) RefType() string { + return "vfs.Mount" +} + +// LeakMessage implements refsvfs2.CheckedObject.LeakMessage. +func (mnt *Mount) LeakMessage() string { + return fmt.Sprintf("[vfs.Mount %p] reference count of %d instead of 0", mnt, atomic.LoadInt64(&mnt.refs)) +} + +// LogRefs implements refsvfs2.CheckedObject.LogRefs. +// +// This should only be set to true for debugging purposes, as it can generate an +// extremely large amount of output and drastically degrade performance. +func (mnt *Mount) LogRefs() bool { + return false } // DecRef decrements mntns' reference count. diff --git a/pkg/sentry/vfs/mount_test.go b/pkg/sentry/vfs/mount_test.go index cb8c56bd3..cb882a983 100644 --- a/pkg/sentry/vfs/mount_test.go +++ b/pkg/sentry/vfs/mount_test.go @@ -29,7 +29,7 @@ func TestMountTableLookupEmpty(t *testing.T) { parent := &Mount{} point := &Dentry{} if m := mt.Lookup(parent, point); m != nil { - t.Errorf("empty mountTable lookup: got %p, wanted nil", m) + t.Errorf("Empty mountTable lookup: got %p, wanted nil", m) } } @@ -111,13 +111,16 @@ func BenchmarkMountTableParallelLookup(b *testing.B) { k := keys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -167,13 +170,16 @@ func BenchmarkMountMapParallelLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m == nil { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -220,14 +226,17 @@ func BenchmarkMountSyncMapParallelLookup(b *testing.B) { k := keys[i&(numMounts-1)] mi, ok := ms.Load(k) if !ok { - b.Fatalf("lookup failed") + b.Errorf("Lookup failed") + return } m := mi.(*Mount) if parent := m.parent(); parent != k.mount { - b.Fatalf("lookup returned mount with parent %p, wanted %p", parent, k.mount) + b.Errorf("Lookup returned mount with parent %p, wanted %p", parent, k.mount) + return } if point := m.point(); point != k.dentry { - b.Fatalf("lookup returned mount with point %p, wanted %p", point, k.dentry) + b.Errorf("Lookup returned mount with point %p, wanted %p", point, k.dentry) + return } } }() @@ -264,7 +273,7 @@ func BenchmarkMountTableNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m := mt.Lookup(k.mount, k.dentry) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -300,7 +309,7 @@ func BenchmarkMountMapNegativeLookup(b *testing.B) { m := ms[k] mu.RUnlock() if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) @@ -333,7 +342,7 @@ func BenchmarkMountSyncMapNegativeLookup(b *testing.B) { k := negkeys[i&(numMounts-1)] m, _ := ms.Load(k) if m != nil { - b.Fatalf("lookup got %p, wanted nil", m) + b.Fatalf("Lookup got %p, wanted nil", m) } } }) diff --git a/pkg/sentry/vfs/mount_unsafe.go b/pkg/sentry/vfs/mount_unsafe.go index b7d122d22..cb48c37a1 100644 --- a/pkg/sentry/vfs/mount_unsafe.go +++ b/pkg/sentry/vfs/mount_unsafe.go @@ -98,7 +98,6 @@ type mountTable struct { // length and cap in separate uint32s) for ~free. size uint64 - // FIXME(gvisor.dev/issue/1663): Slots need to be saved. slots unsafe.Pointer `state:"nosave"` // []mountSlot; never nil after Init } @@ -212,6 +211,26 @@ loop: } } +// Range calls f on each Mount in mt. If f returns false, Range stops iteration +// and returns immediately. +func (mt *mountTable) Range(f func(*Mount) bool) { + tcap := uintptr(1) << (mt.size & mtSizeOrderMask) + slotPtr := mt.slots + last := unsafe.Pointer(uintptr(mt.slots) + ((tcap - 1) * mountSlotBytes)) + for { + slot := (*mountSlot)(slotPtr) + if slot.value != nil { + if !f((*Mount)(slot.value)) { + return + } + } + if slotPtr == last { + return + } + slotPtr = unsafe.Pointer(uintptr(slotPtr) + mountSlotBytes) + } +} + // Insert inserts the given mount into mt. // // Preconditions: mt must not already contain a Mount with the same mount point diff --git a/pkg/sentry/vfs/save_restore.go b/pkg/sentry/vfs/save_restore.go new file mode 100644 index 000000000..7723ed643 --- /dev/null +++ b/pkg/sentry/vfs/save_restore.go @@ -0,0 +1,124 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package vfs + +import ( + "fmt" + "sync/atomic" + + "gvisor.dev/gvisor/pkg/context" + "gvisor.dev/gvisor/pkg/refsvfs2" +) + +// FilesystemImplSaveRestoreExtension is an optional extension to +// FilesystemImpl. +type FilesystemImplSaveRestoreExtension interface { + // PrepareSave prepares this filesystem for serialization. + PrepareSave(ctx context.Context) error + + // CompleteRestore completes restoration from checkpoint for this + // filesystem after deserialization. + CompleteRestore(ctx context.Context, opts CompleteRestoreOptions) error +} + +// PrepareSave prepares all filesystems for serialization. +func (vfs *VirtualFilesystem) PrepareSave(ctx context.Context) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.PrepareSave(ctx); err != nil { + ctx.Warningf("%T.PrepareSave failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to prepare for serialization", failures) + } + return nil +} + +// CompleteRestore completes restoration from checkpoint for all filesystems +// after deserialization. +func (vfs *VirtualFilesystem) CompleteRestore(ctx context.Context, opts *CompleteRestoreOptions) error { + failures := 0 + for fs := range vfs.getFilesystems() { + if ext, ok := fs.impl.(FilesystemImplSaveRestoreExtension); ok { + if err := ext.CompleteRestore(ctx, *opts); err != nil { + ctx.Warningf("%T.CompleteRestore failed: %v", fs.impl, err) + failures++ + } + } + fs.DecRef(ctx) + } + if failures != 0 { + return fmt.Errorf("%d filesystems failed to complete restore after deserialization", failures) + } + return nil +} + +// CompleteRestoreOptions contains options to +// VirtualFilesystem.CompleteRestore() and +// FilesystemImplSaveRestoreExtension.CompleteRestore(). +type CompleteRestoreOptions struct { + // If ValidateFileSizes is true, filesystem implementations backed by + // remote filesystems should verify that file sizes have not changed + // between checkpoint and restore. + ValidateFileSizes bool + + // If ValidateFileModificationTimestamps is true, filesystem + // implementations backed by remote filesystems should validate that file + // mtimes have not changed between checkpoint and restore. + ValidateFileModificationTimestamps bool +} + +// saveMounts is called by stateify. +func (vfs *VirtualFilesystem) saveMounts() []*Mount { + if atomic.LoadPointer(&vfs.mounts.slots) == nil { + // vfs.Init() was never called. + return nil + } + var mounts []*Mount + vfs.mounts.Range(func(mount *Mount) bool { + mounts = append(mounts, mount) + return true + }) + return mounts +} + +// loadMounts is called by stateify. +func (vfs *VirtualFilesystem) loadMounts(mounts []*Mount) { + if mounts == nil { + return + } + vfs.mounts.Init() + for _, mount := range mounts { + vfs.mounts.Insert(mount) + } +} + +func (mnt *Mount) afterLoad() { + if atomic.LoadInt64(&mnt.refs) != 0 { + refsvfs2.Register(mnt) + } +} + +// afterLoad is called by stateify. +func (epi *epollInterest) afterLoad() { + // Mark all epollInterests as ready after restore so that the next call to + // EpollInstance.ReadEvents() rechecks their readiness. + epi.Callback(nil) +} diff --git a/pkg/sentry/vfs/vfs.go b/pkg/sentry/vfs/vfs.go index 38d2701d2..48d6252f7 100644 --- a/pkg/sentry/vfs/vfs.go +++ b/pkg/sentry/vfs/vfs.go @@ -71,7 +71,7 @@ type VirtualFilesystem struct { // points. // // mounts is analogous to Linux's mount_hashtable. - mounts mountTable + mounts mountTable `state:".([]*Mount)"` // mountpoints maps mount points to mounts at those points in all // namespaces. mountpoints is protected by mountMu. @@ -780,23 +780,27 @@ func (vfs *VirtualFilesystem) RemoveXattrAt(ctx context.Context, creds *auth.Cre // SyncAllFilesystems has the semantics of Linux's sync(2). func (vfs *VirtualFilesystem) SyncAllFilesystems(ctx context.Context) error { + var retErr error + for fs := range vfs.getFilesystems() { + if err := fs.impl.Sync(ctx); err != nil && retErr == nil { + retErr = err + } + fs.DecRef(ctx) + } + return retErr +} + +func (vfs *VirtualFilesystem) getFilesystems() map[*Filesystem]struct{} { fss := make(map[*Filesystem]struct{}) vfs.filesystemsMu.Lock() + defer vfs.filesystemsMu.Unlock() for fs := range vfs.filesystems { if !fs.TryIncRef() { continue } fss[fs] = struct{}{} } - vfs.filesystemsMu.Unlock() - var retErr error - for fs := range fss { - if err := fs.impl.Sync(ctx); err != nil && retErr == nil { - retErr = err - } - fs.DecRef(ctx) - } - return retErr + return fss } // MkdirAllAt recursively creates non-existent directories on the given path diff --git a/pkg/shim/runsc/BUILD b/pkg/shim/runsc/BUILD index f08599ebd..cb0001852 100644 --- a/pkg/shim/runsc/BUILD +++ b/pkg/shim/runsc/BUILD @@ -10,6 +10,7 @@ go_library( ], visibility = ["//:sandbox"], deps = [ + "@com_github_containerd_containerd//log:go_default_library", "@com_github_containerd_go_runc//:go_default_library", "@com_github_opencontainers_runtime_spec//specs-go:go_default_library", ], diff --git a/pkg/shim/runsc/runsc.go b/pkg/shim/runsc/runsc.go index c5cf68efa..e7c9640ba 100644 --- a/pkg/shim/runsc/runsc.go +++ b/pkg/shim/runsc/runsc.go @@ -28,10 +28,12 @@ import ( "syscall" "time" + "github.com/containerd/containerd/log" runc "github.com/containerd/go-runc" specs "github.com/opencontainers/runtime-spec/specs-go" ) +// Monitor is the default process monitor to be used by runsc. var Monitor runc.ProcessMonitor = runc.Monitor // DefaultCommand is the default command for Runsc. @@ -74,6 +76,7 @@ func (r *Runsc) State(context context.Context, id string) (*runc.Container, erro return &c, nil } +// CreateOpts is a set of options to Runsc.Create(). type CreateOpts struct { runc.IO ConsoleSocket runc.ConsoleSocket @@ -197,6 +200,7 @@ func (r *Runsc) Wait(context context.Context, id string) (int, error) { return res.ExitStatus, nil } +// ExecOpts is a set of options to runsc.Exec(). type ExecOpts struct { runc.IO PidFile string @@ -301,6 +305,7 @@ func (r *Runsc) Run(context context.Context, id, bundle string, opts *CreateOpts return Monitor.Wait(cmd, ec) } +// DeleteOpts is a set of options to runsc.Delete(). type DeleteOpts struct { Force bool } @@ -367,6 +372,13 @@ func (r *Runsc) Stats(context context.Context, id string) (*runc.Stats, error) { if err := json.NewDecoder(rd).Decode(&e); err != nil { return nil, err } + log.L.Debugf("Stats returned: %+v", e.Stats) + if e.Type != "stats" { + return nil, fmt.Errorf(`unexpected event type %q, wanted "stats"`, e.Type) + } + if e.Stats == nil { + return nil, fmt.Errorf(`"runsc events -stat" succeeded but no stat was provided`) + } return e.Stats, nil } diff --git a/pkg/state/BUILD b/pkg/state/BUILD index 089b3bbef..92c51879b 100644 --- a/pkg/state/BUILD +++ b/pkg/state/BUILD @@ -4,19 +4,6 @@ load("//tools/go_generics:defs.bzl", "go_template_instance") package(licenses = ["notice"]) go_template_instance( - name = "pending_list", - out = "pending_list.go", - package = "state", - prefix = "pending", - template = "//pkg/ilist:generic_list", - types = { - "Element": "*objectEncodeState", - "ElementMapper": "pendingMapper", - "Linker": "*pendingEntry", - }, -) - -go_template_instance( name = "deferred_list", out = "deferred_list.go", package = "state", @@ -83,7 +70,6 @@ go_library( "deferred_list.go", "encode.go", "encode_unsafe.go", - "pending_list.go", "state.go", "state_norace.go", "state_race.go", diff --git a/pkg/state/decode.go b/pkg/state/decode.go index 89467ca8e..e519ddeca 100644 --- a/pkg/state/decode.go +++ b/pkg/state/decode.go @@ -21,6 +21,7 @@ import ( "math" "reflect" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/state/wire" ) @@ -258,7 +259,7 @@ func (ds *decodeState) waitObject(ods *objectDecodeState, encoded wire.Object, c // For the purposes of this function, a child object is either a field within a // struct or an array element, with one such indirection per element in // path. The returned value may be an unexported field, so it may not be -// directly assignable. See unsafePointerTo. +// directly assignable. See decode_unsafe.go. func walkChild(path []wire.Dot, obj reflect.Value) reflect.Value { // See wire.Ref.Dots. The path here is specified in reverse order. for i := len(path) - 1; i >= 0; i-- { @@ -519,9 +520,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // Normal assignment: authoritative only if no dots. v := ds.register(x, obj.Type().Elem()) - if v.IsValid() { - obj.Set(unsafePointerTo(v)) - } + obj.Set(reflectValueRWAddr(v)) case wire.Bool: obj.SetBool(bool(x)) case wire.Int: @@ -559,7 +558,7 @@ func (ds *decodeState) decodeObject(ods *objectDecodeState, obj reflect.Value, e // contents will still be filled in later on. typ := reflect.ArrayOf(int(x.Capacity), obj.Type().Elem()) // The object type. v := ds.register(&x.Ref, typ) - obj.Set(v.Slice3(0, int(x.Length), int(x.Capacity))) + obj.Set(reflectValueRWSlice3(v, 0, int(x.Length), int(x.Capacity))) case *wire.Array: ds.decodeArray(ods, obj, x) case *wire.Struct: @@ -592,7 +591,7 @@ func (ds *decodeState) Load(obj reflect.Value) { ds.pending.PushBack(rootOds) // Read the number of objects. - lastID, object, err := ReadHeader(ds.r) + numObjects, object, err := ReadHeader(ds.r) if err != nil { Failf("header error: %w", err) } @@ -604,42 +603,44 @@ func (ds *decodeState) Load(obj reflect.Value) { var ( encoded wire.Object ods *objectDecodeState - id = objectID(1) + id objectID tid = typeID(1) ) if err := safely(func() { // Decode all objects in the stream. // - // Note that the structure of this decoding loop should match - // the raw decoding loop in printer.go. - for id <= objectID(lastID) { - // Unmarshal the object. + // Note that the structure of this decoding loop should match the raw + // decoding loop in state/pretty/pretty.printer.printStream(). + for i := uint64(0); i < numObjects; { + // Unmarshal either a type object or object ID. encoded = wire.Load(ds.r) - - // Is this a type object? Handle inline. - if wt, ok := encoded.(*wire.Type); ok { - ds.types.Register(wt) + switch we := encoded.(type) { + case *wire.Type: + ds.types.Register(we) tid++ encoded = nil continue + case wire.Uint: + id = objectID(we) + i++ + // Unmarshal and resolve the actual object. + encoded = wire.Load(ds.r) + ods = ds.lookup(id) + if ods != nil { + // Decode the object. + ds.decodeObject(ods, ods.obj, encoded) + } else { + // If an object hasn't had interest registered + // previously or isn't yet valid, we deferred + // decoding until interest is registered. + ds.deferred[id] = encoded + } + // For error handling. + ods = nil + encoded = nil + default: + Failf("wanted type or object ID, got %#v", encoded) } - - // Actually resolve the object. - ods = ds.lookup(id) - if ods != nil { - // Decode the object. - ds.decodeObject(ods, ods.obj, encoded) - } else { - // If an object hasn't had interest registered - // previously or isn't yet valid, we deferred - // decoding until interest is registered. - ds.deferred[id] = encoded - } - - // For error handling. - ods = nil - encoded = nil - id++ } }); err != nil { // Include as much information as we can, taking into account @@ -647,16 +648,25 @@ func (ds *decodeState) Load(obj reflect.Value) { if ods != nil { Failf("error decoding object ID %d (%T) from %#v: %w", id, ods.obj.Interface(), encoded, err) } else if encoded != nil { - Failf("lookup error decoding object ID %d from %#v: %w", id, encoded, err) + Failf("error decoding from %#v: %w", encoded, err) } else { Failf("general decoding error: %w", err) } } // Check if we have any deferred objects. + numDeferred := 0 for id, encoded := range ds.deferred { - // Shoud never happen, the graph was bogus. - Failf("still have deferred objects: one is ID %d, %#v", id, encoded) + numDeferred++ + if s, ok := encoded.(*wire.Struct); ok && s.TypeID != 0 { + typ := ds.types.LookupType(typeID(s.TypeID)) + log.Warningf("unused deferred object: ID %d, type %v", id, typ) + } else { + log.Warningf("unused deferred object: ID %d, %#v", id, encoded) + } + } + if numDeferred != 0 { + Failf("still had %d deferred objects", numDeferred) } // Scan and fire all callbacks. We iterate over the list of incomplete diff --git a/pkg/state/decode_unsafe.go b/pkg/state/decode_unsafe.go index d048f61a1..f1208e2a2 100644 --- a/pkg/state/decode_unsafe.go +++ b/pkg/state/decode_unsafe.go @@ -15,13 +15,62 @@ package state import ( + "fmt" "reflect" + "runtime" "unsafe" ) -// unsafePointerTo is logically equivalent to reflect.Value.Addr, but works on -// values representing unexported fields. This bypasses visibility, but not -// type safety. -func unsafePointerTo(obj reflect.Value) reflect.Value { +// reflectValueRWAddr is equivalent to obj.Addr(), except that the returned +// reflect.Value is usable in assignments even if obj was obtained by the use +// of unexported struct fields. +// +// Preconditions: obj.CanAddr(). +func reflectValueRWAddr(obj reflect.Value) reflect.Value { return reflect.NewAt(obj.Type(), unsafe.Pointer(obj.UnsafeAddr())) } + +// reflectValueRWSlice3 is equivalent to arr.Slice3(i, j, k), except that the +// returned reflect.Value is usable in assignments even if obj was obtained by +// the use of unexported struct fields. +// +// Preconditions: +// * arr.Kind() == reflect.Array. +// * i, j, k >= 0. +// * i <= j <= k <= arr.Len(). +func reflectValueRWSlice3(arr reflect.Value, i, j, k int) reflect.Value { + if arr.Kind() != reflect.Array { + panic(fmt.Sprintf("arr has kind %v, wanted %v", arr.Kind(), reflect.Array)) + } + if i < 0 || j < 0 || k < 0 { + panic(fmt.Sprintf("negative subscripts (%d, %d, %d)", i, j, k)) + } + if i > j { + panic(fmt.Sprintf("subscript i (%d) > j (%d)", i, j)) + } + if j > k { + panic(fmt.Sprintf("subscript j (%d) > k (%d)", j, k)) + } + if k > arr.Len() { + panic(fmt.Sprintf("subscript k (%d) > array length (%d)", k, arr.Len())) + } + + sliceTyp := reflect.SliceOf(arr.Type().Elem()) + if i == arr.Len() { + // By precondition, i == j == k == arr.Len(). + return reflect.MakeSlice(sliceTyp, 0, 0) + } + slh := reflect.SliceHeader{ + // reflect.Value.CanAddr() == false for arrays, so we need to get the + // address from the first element of the array. + Data: arr.Index(i).UnsafeAddr(), + Len: j - i, + Cap: k - i, + } + slobj := reflect.NewAt(sliceTyp, unsafe.Pointer(&slh)).Elem() + // Before slobj is constructed, arr holds the only pointer-typed pointer to + // the array since reflect.SliceHeader.Data is a uintptr, so arr must be + // kept alive. + runtime.KeepAlive(arr) + return slobj +} diff --git a/pkg/state/encode.go b/pkg/state/encode.go index 92fcad4e9..560e7c2a3 100644 --- a/pkg/state/encode.go +++ b/pkg/state/encode.go @@ -17,13 +17,14 @@ package state import ( "context" "reflect" + "sort" "gvisor.dev/gvisor/pkg/state/wire" ) // objectEncodeState the type and identity of an object occupying a memory // address range. This is the value type for addrSet, and the intrusive entry -// for the pending and deferred lists. +// for the deferred list. type objectEncodeState struct { // id is the assigned ID for this object. id objectID @@ -47,7 +48,6 @@ type objectEncodeState struct { // references may be updated directly and automatically. refs []*wire.Ref - pendingEntry deferredEntry } @@ -93,9 +93,15 @@ type encodeState struct { // serialized. pendingTypes []wire.Type - // pending is the list of objects to be serialized. Serialization does + // pending maps object IDs to objects to be serialized. Serialization does // not actually occur until the full object graph is computed. - pending pendingList + pending map[objectID]*objectEncodeState + + // encodedStructs maps reflect.Values representing structs to previous + // encodings of those structs. This is necessary to avoid duplicate calls + // to SaverLoader.StateSave() that may result in multiple calls to + // Sink.SaveValue() for a given field, resulting in object duplication. + encodedStructs map[reflect.Value]*wire.Struct // stats tracks time data. stats Stats @@ -189,7 +195,8 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { // depending on this value knows there's nothing there. return } - if seg, _ := es.values.Find(addr); seg.Ok() { + seg, gap := es.values.Find(addr) + if seg.Ok() { // Ensure the map types match. existing := seg.Value() if existing.obj.Type() != obj.Type() { @@ -203,13 +210,20 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { } // Record the map. + r := addrRange{addr, addr + 1} oes := &objectEncodeState{ id: es.nextID(), obj: obj, how: encodeMapAsValue, } - es.values.Add(addrRange{addr, addr + 1}, oes) - es.pending.PushBack(oes) + // Use Insert instead of InsertWithoutMergingUnchecked when race + // detection is enabled to get additional sanity-checking from Merge. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes es.deferred.PushBack(oes) // See above: no ref recording. @@ -245,7 +259,7 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { obj: obj, } es.zeroValues[typ] = oes - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) } @@ -258,86 +272,112 @@ func (es *encodeState) resolve(obj reflect.Value, ref *wire.Ref) { size = 1 // See above. } - // Calculate the container. end := addr + size r := addrRange{addr, end} - if seg, _ := es.values.Find(addr); seg.Ok() { + seg := es.values.LowerBoundSegment(addr) + var ( + oes *objectEncodeState + gap addrGapIterator + ) + + // Does at least one previously-registered object overlap this one? + if seg.Ok() && seg.Start() < end { existing := seg.Value() - switch { - case seg.Start() == addr && seg.End() == end && obj.Type() == existing.obj.Type(): - // The object is a perfect match. Happy path. Avoid the - // traversal and just return directly. We don't need to - // encode the type information or any dots here. + + if seg.Range() == r && typ == existing.obj.Type() { + // This exact object is already registered. Avoid the traversal and + // just return directly. We don't need to encode the type + // information or any dots here. ref.Root = wire.Uint(existing.id) existing.refs = append(existing.refs, ref) return + } - case (seg.Start() < addr && seg.End() >= end) || (seg.Start() <= addr && seg.End() > end): - // The previously registered object is larger than - // this, no need to update. But we expect some - // traversal below. + if seg.Range().IsSupersetOf(r) && (seg.Range() != r || isSameSizeParent(existing.obj, typ)) { + // This object is contained within a previously-registered object. + // Perform traversal from the container to the new object. + ref.Root = wire.Uint(existing.id) + ref.Dots = traverse(existing.obj.Type(), typ, seg.Start(), addr) + ref.Type = es.findType(existing.obj.Type()) + existing.refs = append(existing.refs, ref) + return + } - case seg.Start() == addr && seg.End() == end: - if !isSameSizeParent(obj, existing.obj.Type()) { - break // Needs traversal. + // This object contains one or more previously-registered objects. + // Remove them and update existing references to use the new one. + oes := &objectEncodeState{ + // Reuse the root ID of the first contained element. + id: existing.id, + obj: obj, + } + type elementEncodeState struct { + addr uintptr + typ reflect.Type + refs []*wire.Ref + } + var ( + elems []elementEncodeState + gap addrGapIterator + ) + for { + // Each contained object should be completely contained within + // this one. + if raceEnabled && !r.IsSupersetOf(seg.Range()) { + Failf("containing object %#v does not contain existing object %#v", obj, existing.obj) } - fallthrough // Needs update. - - case (seg.Start() > addr && seg.End() <= end) || (seg.Start() >= addr && seg.End() < end): - // Update the object and redo the encoding. - old := existing.obj - existing.obj = obj + elems = append(elems, elementEncodeState{ + addr: seg.Start(), + typ: existing.obj.Type(), + refs: existing.refs, + }) + delete(es.pending, existing.id) es.deferred.Remove(existing) - es.deferred.PushBack(existing) - - // The previously registered object is superseded by - // this new object. We are guaranteed to not have any - // mergeable neighbours in this segment set. - if !raceEnabled { - seg.SetRangeUnchecked(r) - } else { - // Add extra paranoid. This will be statically - // removed at compile time unless a race build. - es.values.Remove(seg) - es.values.Add(r, existing) - seg = es.values.LowerBoundSegment(addr) + gap = es.values.Remove(seg) + seg = gap.NextSegment() + if !seg.Ok() || seg.Start() >= end { + break } - - // Compute the traversal required & update references. - dots := traverse(obj.Type(), old.Type(), addr, seg.Start()) - wt := es.findType(obj.Type()) - for _, ref := range existing.refs { + existing = seg.Value() + } + wt := es.findType(typ) + for _, elem := range elems { + dots := traverse(typ, elem.typ, addr, elem.addr) + for _, ref := range elem.refs { + ref.Root = wire.Uint(oes.id) ref.Dots = append(ref.Dots, dots...) ref.Type = wt } - default: - // There is a non-sensical overlap. - Failf("overlapping objects: [new object] %#v [existing object] %#v", obj, existing.obj) + oes.refs = append(oes.refs, elem.refs...) } - - // Compute the new reference, record and return it. - ref.Root = wire.Uint(existing.id) - ref.Dots = traverse(existing.obj.Type(), obj.Type(), seg.Start(), addr) - ref.Type = es.findType(obj.Type()) - existing.refs = append(existing.refs, ref) + // Finally register the new containing object. + if !raceEnabled { + es.values.InsertWithoutMergingUnchecked(gap, r, oes) + } else { + es.values.Insert(gap, r, oes) + } + es.pending[oes.id] = oes + es.deferred.PushBack(oes) + ref.Root = wire.Uint(oes.id) + oes.refs = append(oes.refs, ref) return } - // The only remaining case is a pointer value that doesn't overlap with - // any registered addresses. Create a new entry for it, and start - // tracking the first reference we just created. - oes := &objectEncodeState{ + // No existing object overlaps this one. Register a new object. + oes = &objectEncodeState{ id: es.nextID(), obj: obj, } + if seg.Ok() { + gap = seg.PrevGap() + } else { + gap = es.values.LastGap() + } if !raceEnabled { - es.values.AddWithoutMerging(r, oes) + es.values.InsertWithoutMergingUnchecked(gap, r, oes) } else { - // Merges should never happen. This is just enabled extra - // sanity checks because the Merge function below will panic. - es.values.Add(r, oes) + es.values.Insert(gap, r, oes) } - es.pending.PushBack(oes) + es.pending[oes.id] = oes es.deferred.PushBack(oes) ref.Root = wire.Uint(oes.id) oes.refs = append(oes.refs, ref) @@ -439,6 +479,14 @@ func (oe *objectEncoder) save(slot int, obj reflect.Value) { // encodeStruct encodes a composite object. func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { + if s, ok := es.encodedStructs[obj]; ok { + *dest = s + return + } + s := &wire.Struct{} + *dest = s + es.encodedStructs[obj] = s + // Ensure that the obj is addressable. There are two cases when it is // not. First, is when this is dispatched via SaveValue. Second, when // this is a map key as a struct. Either way, we need to make a copy to @@ -449,10 +497,6 @@ func (es *encodeState) encodeStruct(obj reflect.Value, dest *wire.Object) { obj = localObj.Elem() } - // Prepare the value. - s := &wire.Struct{} - *dest = s - // Look the type up in the database. te, ok := es.types.Lookup(obj.Type()) if te == nil { @@ -730,45 +774,43 @@ func (es *encodeState) Save(obj reflect.Value) { Failf("encoding error at object %#v: %w", oes.obj.Interface(), err) } - // Check that items are pending. - if es.pending.Front() == nil { + // Check that we have objects to serialize. + if len(es.pending) == 0 { Failf("pending is empty?") } - // Write the header with the number of objects. Note that there is no - // way that es.lastID could conflict with objectID, which would - // indicate that an impossibly large encoding. - if err := WriteHeader(es.w, uint64(es.lastID), true); err != nil { + // Write the header with the number of objects. + if err := WriteHeader(es.w, uint64(len(es.pending)), true); err != nil { Failf("error writing header: %w", err) } // Serialize all pending types and pending objects. Note that we don't // bother removing from this list as we walk it because that just // wastes time. It will not change after this point. - var id objectID if err := safely(func() { for _, wt := range es.pendingTypes { // Encode the type. wire.Save(es.w, &wt) } - for oes = es.pending.Front(); oes != nil; oes = oes.pendingEntry.Next() { - id++ // First object is 1. - if oes.id != id { - Failf("expected id %d, got %d", id, oes.id) - } - - // Marshall the object. + // Emit objects in ID order. + ids := make([]objectID, 0, len(es.pending)) + for id := range es.pending { + ids = append(ids, id) + } + sort.Slice(ids, func(i, j int) bool { + return ids[i] < ids[j] + }) + for _, id := range ids { + // Encode the id. + wire.Save(es.w, wire.Uint(id)) + // Marshal the object. + oes := es.pending[id] wire.Save(es.w, oes.encoded) } }); err != nil { // Include the object and the error. Failf("error serializing object %#v: %w", oes.encoded, err) } - - // Check what we wrote. - if id != es.lastID { - Failf("expected %d objects, wrote %d", es.lastID, id) - } } // objectFlag indicates that the length is a # of objects, rather than a raw @@ -797,11 +839,6 @@ func WriteHeader(w wire.Writer, length uint64, object bool) error { }) } -// pendingMapper is for the pending list. -type pendingMapper struct{} - -func (pendingMapper) linkerFor(oes *objectEncodeState) *pendingEntry { return &oes.pendingEntry } - // deferredMapper is for the deferred list. type deferredMapper struct{} diff --git a/pkg/state/pretty/pretty.go b/pkg/state/pretty/pretty.go index 887f453a9..c6e8bb31d 100644 --- a/pkg/state/pretty/pretty.go +++ b/pkg/state/pretty/pretty.go @@ -42,6 +42,7 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { buf.WriteString(typ) buf.WriteString(")(") buf.WriteString(baseRef) + buf.WriteString(")") for _, component := range x.Dots { switch v := component.(type) { case *wire.FieldName: @@ -53,7 +54,6 @@ func (p *printer) formatRef(x *wire.Ref, graph uint64) string { panic(fmt.Sprintf("unreachable: switch should be exhaustive, unhandled case %v", reflect.TypeOf(component))) } } - buf.WriteString(")") fullRef = buf.String() } if p.html { @@ -242,19 +242,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { // Note that this loop must match the general structure of the // loop in decode.go. But we don't register type information, // etc. and just print the raw structures. + type objectAndID struct { + id uint64 + obj wire.Object + } var ( tid uint64 = 1 - objects []wire.Object + objects []objectAndID ) - for oid := uint64(1); oid <= length; { - // Unmarshal the object. + for i := uint64(0); i < length; { + // Unmarshal either a type object or object ID. encoded := wire.Load(r) - - // Is this a type? - if typ, ok := encoded.(*wire.Type); ok { + switch we := encoded.(type) { + case *wire.Type: str, _ := p.format(graph, 0, encoded) tag := fmt.Sprintf("g%dt%d", graph, tid) - p.typeSpecs[tag] = typ + p.typeSpecs[tag] = we if p.html { // See below. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) @@ -263,20 +266,22 @@ func (p *printer) printStream(w io.Writer, r wire.Reader) (err error) { return err } tid++ - continue + case wire.Uint: + // Unmarshal the actual object. + objects = append(objects, objectAndID{ + id: uint64(we), + obj: wire.Load(r), + }) + i++ + default: + return fmt.Errorf("wanted type or object ID, got %#v", encoded) } - - // Otherwise, it is a node. - objects = append(objects, encoded) - oid++ } - for i, encoded := range objects { - // oid starts at 1. - oid := i + 1 + for _, objAndID := range objects { // Format the node. - str, _ := p.format(graph, 0, encoded) - tag := fmt.Sprintf("g%dr%d", graph, oid) + str, _ := p.format(graph, 0, objAndID.obj) + tag := fmt.Sprintf("g%dr%d", graph, objAndID.id) if p.html { // Create a little tag with an anchor next to it for linking. tag = fmt.Sprintf("<a name=\"%s\">%s</a><a href=\"#%s\">⚓</a>", tag, tag, tag) diff --git a/pkg/state/state.go b/pkg/state/state.go index acb629969..6b8540f03 100644 --- a/pkg/state/state.go +++ b/pkg/state/state.go @@ -90,10 +90,12 @@ func (e *ErrState) Unwrap() error { func Save(ctx context.Context, w wire.Writer, rootPtr interface{}) (Stats, error) { // Create the encoding state. es := encodeState{ - ctx: ctx, - w: w, - types: makeTypeEncodeDatabase(), - zeroValues: make(map[reflect.Type]*objectEncodeState), + ctx: ctx, + w: w, + types: makeTypeEncodeDatabase(), + zeroValues: make(map[reflect.Type]*objectEncodeState), + pending: make(map[objectID]*objectEncodeState), + encodedStructs: make(map[reflect.Value]*wire.Struct), } // Perform the encoding. diff --git a/pkg/state/tests/struct.go b/pkg/state/tests/struct.go index bd2c2b399..69143d194 100644 --- a/pkg/state/tests/struct.go +++ b/pkg/state/tests/struct.go @@ -54,12 +54,47 @@ type outerArray struct { } // +stateify savable +type outerSlice struct { + inner []inner +} + +// +stateify savable type inner struct { v int64 } // +stateify savable +type outerFieldValue struct { + inner innerFieldValue +} + +// +stateify savable +type innerFieldValue struct { + v int64 `state:".(*savedFieldValue)"` +} + +// +stateify savable +type savedFieldValue struct { + v int64 +} + +func (ifv *innerFieldValue) saveV() *savedFieldValue { + return &savedFieldValue{ifv.v} +} + +func (ifv *innerFieldValue) loadV(sfv *savedFieldValue) { + ifv.v = sfv.v +} + +// +stateify savable type system struct { v1 interface{} v2 interface{} } + +// +stateify savable +type system3 struct { + v1 interface{} + v2 interface{} + v3 interface{} +} diff --git a/pkg/state/tests/struct_test.go b/pkg/state/tests/struct_test.go index de9d17aa7..c91c2c032 100644 --- a/pkg/state/tests/struct_test.go +++ b/pkg/state/tests/struct_test.go @@ -15,6 +15,7 @@ package tests import ( + "math/rand" "testing" "gvisor.dev/gvisor/pkg/state" @@ -67,12 +68,23 @@ func TestRegisterTypeOnlyStruct(t *testing.T) { } func TestEmbeddedPointers(t *testing.T) { - var ( - ofs outerSame - of1 outerFieldFirst - of2 outerFieldSecond - oa outerArray - ) + // Give each int64 a random value to prevent Go from using + // runtime.staticuint64s, which confounds tests for struct duplication. + magic := func() int64 { + for { + n := rand.Int63() + if n < 0 || n > 255 { + return n + } + } + } + + ofs := outerSame{inner{magic()}} + of1 := outerFieldFirst{inner{magic()}, magic()} + of2 := outerFieldSecond{magic(), inner{magic()}} + oa := outerArray{[2]inner{{magic()}, {magic()}}} + osl := outerSlice{oa.inner[:]} + ofv := outerFieldValue{innerFieldValue{magic()}} runTestCases(t, false, "embedded-pointers", []interface{}{ system{&ofs, &ofs.inner}, @@ -85,5 +97,15 @@ func TestEmbeddedPointers(t *testing.T) { system{&oa, &oa.inner[1]}, system{&oa.inner[0], &oa}, system{&oa.inner[1], &oa}, + system3{&oa, &oa.inner[0], &oa.inner[1]}, + system3{&oa, &oa.inner[1], &oa.inner[0]}, + system3{&oa.inner[0], &oa, &oa.inner[1]}, + system3{&oa.inner[1], &oa, &oa.inner[0]}, + system3{&oa.inner[0], &oa.inner[1], &oa}, + system3{&oa.inner[1], &oa.inner[0], &oa}, + system{&oa, &osl}, + system{&osl, &oa}, + system{&ofv, &ofv.inner}, + system{&ofv.inner, &ofv}, }) } diff --git a/pkg/tcpip/adapters/gonet/gonet_test.go b/pkg/tcpip/adapters/gonet/gonet_test.go index 12b061def..b196324c7 100644 --- a/pkg/tcpip/adapters/gonet/gonet_test.go +++ b/pkg/tcpip/adapters/gonet/gonet_test.go @@ -97,6 +97,9 @@ type testConnection struct { func connect(s *stack.Stack, addr tcpip.FullAddress) (*testConnection, *tcpip.Error) { wq := &waiter.Queue{} ep, err := s.NewEndpoint(tcp.ProtocolNumber, ipv4.ProtocolNumber, wq) + if err != nil { + return nil, err + } entry, ch := waiter.NewChannelEntry(nil) wq.EventRegister(&entry, waiter.EventOut) @@ -145,7 +148,9 @@ func TestCloseReader(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } // Give c.Read() a chance to block before closing the connection. @@ -416,7 +421,9 @@ func TestDeadlineChange(t *testing.T) { defer close(done) c, err := l.Accept() if err != nil { - t.Fatalf("l.Accept() = %v", err) + t.Errorf("l.Accept() = %v", err) + // Cannot call Fatalf in goroutine. Just return from the goroutine. + return } c.SetDeadline(time.Now().Add(time.Minute)) diff --git a/pkg/tcpip/checker/checker.go b/pkg/tcpip/checker/checker.go index 6f81b0164..530f2ae2f 100644 --- a/pkg/tcpip/checker/checker.go +++ b/pkg/tcpip/checker/checker.go @@ -205,7 +205,7 @@ func IPv4Options(want []byte) NetworkChecker { if !ok { t.Fatalf("unexpected network header passed to checker, got = %T, want = header.IPv4", h[0]) } - options := ip.Options() + options := []byte(ip.Options()) // cmp.Diff does not consider nil slices equal to empty slices, but we do. if len(want) == 0 && len(options) == 0 { return @@ -859,6 +859,21 @@ func ICMPv4Seq(want uint16) TransportChecker { } } +// ICMPv4Pointer creates a checker that checks the ICMPv4 Param Problem pointer. +func ICMPv4Pointer(want uint8) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv4, ok := h.(header.ICMPv4) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv4", h) + } + if got := icmpv4.Pointer(); got != want { + t.Fatalf("unexpected ICMP Param Problem pointer, got = %d, want = %d", got, want) + } + } +} + // ICMPv4Checksum creates a checker that checks the ICMPv4 Checksum. // This assumes that the payload exactly makes up the rest of the slice. func ICMPv4Checksum() TransportChecker { @@ -953,6 +968,38 @@ func ICMPv6Code(want header.ICMPv6Code) TransportChecker { } } +// ICMPv6TypeSpecific creates a checker that checks the ICMPv6 TypeSpecific +// field. +func ICMPv6TypeSpecific(want uint32) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + if got := icmpv6.TypeSpecific(); got != want { + t.Fatalf("unexpected ICMP TypeSpecific, got = %d, want = %d", got, want) + } + } +} + +// ICMPv6Payload creates a checker that checks the payload in an ICMPv6 packet. +func ICMPv6Payload(want []byte) TransportChecker { + return func(t *testing.T, h header.Transport) { + t.Helper() + + icmpv6, ok := h.(header.ICMPv6) + if !ok { + t.Fatalf("unexpected transport header passed to checker, got = %T, want = header.ICMPv6", h) + } + payload := icmpv6.Payload() + if diff := cmp.Diff(want, payload); diff != "" { + t.Errorf("ICMP payload mismatch (-want +got):\n%s", diff) + } + } +} + // NDP creates a checker that checks that the packet contains a valid NDP // message for type of ty, with potentially additional checks specified by // checkers. diff --git a/pkg/tcpip/header/icmpv4.go b/pkg/tcpip/header/icmpv4.go index 504408878..2f13dea6a 100644 --- a/pkg/tcpip/header/icmpv4.go +++ b/pkg/tcpip/header/icmpv4.go @@ -99,7 +99,8 @@ const ( // ICMP codes for ICMPv4 Time Exceeded messages as defined in RFC 792. const ( - ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4TTLExceeded ICMPv4Code = 0 + ICMPv4ReassemblyTimeout ICMPv4Code = 1 ) // ICMP codes for ICMPv4 Destination Unreachable messages as defined in RFC 792. @@ -126,6 +127,12 @@ func (b ICMPv4) Code() ICMPv4Code { return ICMPv4Code(b[1]) } // SetCode sets the ICMP code field. func (b ICMPv4) SetCode(c ICMPv4Code) { b[1] = byte(c) } +// Pointer returns the pointer field in a Parameter Problem packet. +func (b ICMPv4) Pointer() byte { return b[icmpv4PointerOffset] } + +// SetPointer sets the pointer field in a Parameter Problem packet. +func (b ICMPv4) SetPointer(c byte) { b[icmpv4PointerOffset] = c } + // Checksum is the ICMP checksum field. func (b ICMPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[icmpv4ChecksumOffset:]) diff --git a/pkg/tcpip/header/ipv4.go b/pkg/tcpip/header/ipv4.go index 4c6e4be64..961b77628 100644 --- a/pkg/tcpip/header/ipv4.go +++ b/pkg/tcpip/header/ipv4.go @@ -16,6 +16,7 @@ package header import ( "encoding/binary" + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -38,7 +39,6 @@ import ( // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Options | Padding | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// const ( versIHL = 0 tos = 1 @@ -93,7 +93,7 @@ type IPv4Fields struct { DstAddr tcpip.Address } -// IPv4 represents an ipv4 header stored in a byte array. +// IPv4 is an IPv4 header. // Most of the methods of IPv4 access to the underlying slice without // checking the boundaries and could panic because of 'index out of range'. // Always call IsValid() to validate an instance of IPv4 before using other @@ -106,10 +106,13 @@ const ( IPv4MinimumSize = 20 // IPv4MaximumHeaderSize is the maximum size of an IPv4 header. Given - // that there are only 4 bits to represents the header length in 32-bit - // units, the header cannot exceed 15*4 = 60 bytes. + // that there are only 4 bits (max 0xF (15)) to represent the header length + // in 32-bit (4 byte) units, the header cannot exceed 15*4 = 60 bytes. IPv4MaximumHeaderSize = 60 + // IPv4MaximumOptionsSize is the largest size the IPv4 options can be. + IPv4MaximumOptionsSize = IPv4MaximumHeaderSize - IPv4MinimumSize + // IPv4MaximumPayloadSize is the maximum size of a valid IPv4 payload. // // Linux limits this to 65,515 octets (the max IP datagram size - the IPv4 @@ -130,7 +133,7 @@ const ( // IPv4ProtocolNumber is IPv4's network protocol number. IPv4ProtocolNumber tcpip.NetworkProtocolNumber = 0x0800 - // IPv4Version is the version of the ipv4 protocol. + // IPv4Version is the version of the IPv4 protocol. IPv4Version = 4 // IPv4AllSystems is the all systems IPv4 multicast address as per @@ -148,6 +151,13 @@ const ( // packet that every IPv4 capable host must be able to // process/reassemble. IPv4MinimumProcessableDatagramSize = 576 + + // IPv4MinimumMTU is the minimum MTU required by IPv4, per RFC 791, + // section 3.2: + // Every internet module must be able to forward a datagram of 68 octets + // without further fragmentation. This is because an internet header may be + // up to 60 octets, and the minimum fragment is 8 octets. + IPv4MinimumMTU = 68 ) // Flags that may be set in an IPv4 packet. @@ -191,14 +201,13 @@ func IPVersion(b []byte) int { // Internet Header Length is the length of the internet header in 32 // bit words, and thus points to the beginning of the data. Note that // the minimum value for a correct header is 5. -// const ( ipVersionShift = 4 ipIHLMask = 0x0f IPv4IHLStride = 4 ) -// HeaderLength returns the value of the "header length" field of the ipv4 +// HeaderLength returns the value of the "header length" field of the IPv4 // header. The length returned is in bytes. func (b IPv4) HeaderLength() uint8 { return (b[versIHL] & ipIHLMask) * IPv4IHLStride @@ -212,17 +221,17 @@ func (b IPv4) SetHeaderLength(hdrLen uint8) { b[versIHL] = (IPv4Version << ipVersionShift) | ((hdrLen / IPv4IHLStride) & ipIHLMask) } -// ID returns the value of the identifier field of the ipv4 header. +// ID returns the value of the identifier field of the IPv4 header. func (b IPv4) ID() uint16 { return binary.BigEndian.Uint16(b[id:]) } -// Protocol returns the value of the protocol field of the ipv4 header. +// Protocol returns the value of the protocol field of the IPv4 header. func (b IPv4) Protocol() uint8 { return b[protocol] } -// Flags returns the "flags" field of the ipv4 header. +// Flags returns the "flags" field of the IPv4 header. func (b IPv4) Flags() uint8 { return uint8(binary.BigEndian.Uint16(b[flagsFO:]) >> 13) } @@ -232,41 +241,44 @@ func (b IPv4) More() bool { return b.Flags()&IPv4FlagMoreFragments != 0 } -// TTL returns the "TTL" field of the ipv4 header. +// TTL returns the "TTL" field of the IPv4 header. func (b IPv4) TTL() uint8 { return b[ttl] } -// FragmentOffset returns the "fragment offset" field of the ipv4 header. +// FragmentOffset returns the "fragment offset" field of the IPv4 header. func (b IPv4) FragmentOffset() uint16 { return binary.BigEndian.Uint16(b[flagsFO:]) << 3 } -// TotalLength returns the "total length" field of the ipv4 header. +// TotalLength returns the "total length" field of the IPv4 header. func (b IPv4) TotalLength() uint16 { return binary.BigEndian.Uint16(b[IPv4TotalLenOffset:]) } -// Checksum returns the checksum field of the ipv4 header. +// Checksum returns the checksum field of the IPv4 header. func (b IPv4) Checksum() uint16 { return binary.BigEndian.Uint16(b[checksum:]) } -// SourceAddress returns the "source address" field of the ipv4 header. +// SourceAddress returns the "source address" field of the IPv4 header. func (b IPv4) SourceAddress() tcpip.Address { return tcpip.Address(b[srcAddr : srcAddr+IPv4AddressSize]) } -// DestinationAddress returns the "destination address" field of the ipv4 +// DestinationAddress returns the "destination address" field of the IPv4 // header. func (b IPv4) DestinationAddress() tcpip.Address { return tcpip.Address(b[dstAddr : dstAddr+IPv4AddressSize]) } -// Options returns a a buffer holding the options. -func (b IPv4) Options() []byte { +// IPv4Options is a buffer that holds all the raw IP options. +type IPv4Options []byte + +// Options returns a buffer holding the options. +func (b IPv4) Options() IPv4Options { hdrLen := b.HeaderLength() - return b[options:hdrLen:hdrLen] + return IPv4Options(b[options:hdrLen:hdrLen]) } // TransportProtocol implements Network.TransportProtocol. @@ -279,17 +291,17 @@ func (b IPv4) Payload() []byte { return b[b.HeaderLength():][:b.PayloadLength()] } -// PayloadLength returns the length of the payload portion of the ipv4 packet. +// PayloadLength returns the length of the payload portion of the IPv4 packet. func (b IPv4) PayloadLength() uint16 { return b.TotalLength() - uint16(b.HeaderLength()) } -// TOS returns the "type of service" field of the ipv4 header. +// TOS returns the "type of service" field of the IPv4 header. func (b IPv4) TOS() (uint8, uint32) { return b[tos], 0 } -// SetTOS sets the "type of service" field of the ipv4 header. +// SetTOS sets the "type of service" field of the IPv4 header. func (b IPv4) SetTOS(v uint8, _ uint32) { b[tos] = v } @@ -299,18 +311,18 @@ func (b IPv4) SetTTL(v byte) { b[ttl] = v } -// SetTotalLength sets the "total length" field of the ipv4 header. +// SetTotalLength sets the "total length" field of the IPv4 header. func (b IPv4) SetTotalLength(totalLength uint16) { binary.BigEndian.PutUint16(b[IPv4TotalLenOffset:], totalLength) } -// SetChecksum sets the checksum field of the ipv4 header. +// SetChecksum sets the checksum field of the IPv4 header. func (b IPv4) SetChecksum(v uint16) { binary.BigEndian.PutUint16(b[checksum:], v) } // SetFlagsFragmentOffset sets the "flags" and "fragment offset" fields of the -// ipv4 header. +// IPv4 header. func (b IPv4) SetFlagsFragmentOffset(flags uint8, offset uint16) { v := (uint16(flags) << 13) | (offset >> 3) binary.BigEndian.PutUint16(b[flagsFO:], v) @@ -321,23 +333,23 @@ func (b IPv4) SetID(v uint16) { binary.BigEndian.PutUint16(b[id:], v) } -// SetSourceAddress sets the "source address" field of the ipv4 header. +// SetSourceAddress sets the "source address" field of the IPv4 header. func (b IPv4) SetSourceAddress(addr tcpip.Address) { copy(b[srcAddr:srcAddr+IPv4AddressSize], addr) } -// SetDestinationAddress sets the "destination address" field of the ipv4 +// SetDestinationAddress sets the "destination address" field of the IPv4 // header. func (b IPv4) SetDestinationAddress(addr tcpip.Address) { copy(b[dstAddr:dstAddr+IPv4AddressSize], addr) } -// CalculateChecksum calculates the checksum of the ipv4 header. +// CalculateChecksum calculates the checksum of the IPv4 header. func (b IPv4) CalculateChecksum() uint16 { return Checksum(b[:b.HeaderLength()], 0) } -// Encode encodes all the fields of the ipv4 header. +// Encode encodes all the fields of the IPv4 header. func (b IPv4) Encode(i *IPv4Fields) { b.SetHeaderLength(i.IHL) b[tos] = i.TOS @@ -351,7 +363,7 @@ func (b IPv4) Encode(i *IPv4Fields) { copy(b[dstAddr:dstAddr+IPv4AddressSize], i.DstAddr) } -// EncodePartial updates the total length and checksum fields of ipv4 header, +// EncodePartial updates the total length and checksum fields of IPv4 header, // taking in the partial checksum, which is the checksum of the header without // the total length and checksum fields. It is useful in cases when similar // packets are produced. @@ -398,3 +410,424 @@ func IsV4LoopbackAddress(addr tcpip.Address) bool { } return addr[0] == 0x7f } + +// ========================= Options ========================== + +// An IPv4OptionType can hold the valuse for the Type in an IPv4 option. +type IPv4OptionType byte + +// These constants are needed to identify individual options in the option list. +// While RFC 791 (page 31) says "Every internet module must be able to act on +// every option." This has not generally been adhered to and some options have +// very low rates of support. We do not support options other than those shown +// below. + +const ( + // IPv4OptionListEndType is the option type for the End Of Option List + // option. Anything following is ignored. + IPv4OptionListEndType IPv4OptionType = 0 + + // IPv4OptionNOPType is the No-Operation option. May appear between other + // options and may appear multiple times. + IPv4OptionNOPType IPv4OptionType = 1 + + // IPv4OptionRecordRouteType is used by each router on the path of the packet + // to record its path. It is carried over to an Echo Reply. + IPv4OptionRecordRouteType IPv4OptionType = 7 + + // IPv4OptionTimestampType is the option type for the Timestamp option. + IPv4OptionTimestampType IPv4OptionType = 68 + + // ipv4OptionTypeOffset is the offset in an option of its type field. + ipv4OptionTypeOffset = 0 + + // IPv4OptionLengthOffset is the offset in an option of its length field. + IPv4OptionLengthOffset = 1 +) + +// Potential errors when parsing generic IP options. +var ( + ErrIPv4OptZeroLength = errors.New("zero length IP option") + ErrIPv4OptDuplicate = errors.New("duplicate IP option") + ErrIPv4OptInvalid = errors.New("invalid IP option") + ErrIPv4OptMalformed = errors.New("malformed IP option") + ErrIPv4OptionTruncated = errors.New("truncated IP option") + ErrIPv4OptionAddress = errors.New("bad IP option address") +) + +// IPv4Option is an interface representing various option types. +type IPv4Option interface { + // Type returns the type identifier of the option. + Type() IPv4OptionType + + // Size returns the size of the option in bytes. + Size() uint8 + + // Contents returns a slice holding the contents of the option. + Contents() []byte +} + +var _ IPv4Option = (*IPv4OptionGeneric)(nil) + +// IPv4OptionGeneric is an IPv4 Option of unknown type. +type IPv4OptionGeneric []byte + +// Type implements IPv4Option. +func (o *IPv4OptionGeneric) Type() IPv4OptionType { + return IPv4OptionType((*o)[ipv4OptionTypeOffset]) +} + +// Size implements IPv4Option. +func (o *IPv4OptionGeneric) Size() uint8 { return uint8(len(*o)) } + +// Contents implements IPv4Option. +func (o *IPv4OptionGeneric) Contents() []byte { return []byte(*o) } + +// IPv4OptionIterator is an iterator pointing to a specific IP option +// at any point of time. It also holds information as to a new options buffer +// that we are building up to hand back to the caller. +type IPv4OptionIterator struct { + options IPv4Options + // ErrCursor is where we are while parsing options. It is exported as any + // resulting ICMP packet is supposed to have a pointer to the byte within + // the IP packet where the error was detected. + ErrCursor uint8 + nextErrCursor uint8 + newOptions [IPv4MaximumOptionsSize]byte + writePoint int +} + +// MakeIterator sets up and returns an iterator of options. It also sets up the +// building of a new option set. +func (o IPv4Options) MakeIterator() IPv4OptionIterator { + return IPv4OptionIterator{ + options: o, + nextErrCursor: IPv4MinimumSize, + } +} + +// RemainingBuffer returns the remaining (unused) part of the new option buffer, +// into which a new option may be written. +func (i *IPv4OptionIterator) RemainingBuffer() IPv4Options { + return IPv4Options(i.newOptions[i.writePoint:]) +} + +// ConsumeBuffer marks a portion of the new buffer as used. +func (i *IPv4OptionIterator) ConsumeBuffer(size int) { + i.writePoint += size +} + +// PushNOPOrEnd puts one of the single byte options onto the new options. +// Only values 0 or 1 (ListEnd or NOP) are valid input. +func (i *IPv4OptionIterator) PushNOPOrEnd(val IPv4OptionType) { + if val > IPv4OptionNOPType { + panic(fmt.Sprintf("invalid option type %d pushed onto option build buffer", val)) + } + i.newOptions[i.writePoint] = byte(val) + i.writePoint++ +} + +// Finalize returns the completed replacement options buffer padded +// as needed. +func (i *IPv4OptionIterator) Finalize() IPv4Options { + // RFC 791 page 31 says: + // The options might not end on a 32-bit boundary. The internet header + // must be filled out with octets of zeros. The first of these would + // be interpreted as the end-of-options option, and the remainder as + // internet header padding. + // Since the buffer is already zero filled we just need to step the write + // pointer up to the next multiple of 4. + options := IPv4Options(i.newOptions[:(i.writePoint+0x3) & ^0x3]) + // Poison the write pointer. + i.writePoint = len(i.newOptions) + return options +} + +// Next returns the next IP option in the buffer/list of IP options. +// It returns +// - A slice of bytes holding the next option or nil if there is error. +// - A boolean which is true if parsing of all the options is complete. +// - An error which is non-nil if an error condition was encountered. +func (i *IPv4OptionIterator) Next() (IPv4Option, bool, error) { + // The opts slice gets shorter as we process the options. When we have no + // bytes left we are done. + if len(i.options) == 0 { + return nil, true, nil + } + + i.ErrCursor = i.nextErrCursor + + optType := IPv4OptionType(i.options[ipv4OptionTypeOffset]) + + if optType == IPv4OptionNOPType || optType == IPv4OptionListEndType { + optionBody := i.options[:1] + i.options = i.options[1:] + i.nextErrCursor = i.ErrCursor + 1 + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil + } + + // There are no more single byte options defined. All the rest have a length + // field so we need to sanity check it. + if len(i.options) == 1 { + return nil, true, ErrIPv4OptMalformed + } + + optLen := i.options[IPv4OptionLengthOffset] + + if optLen == 0 { + i.ErrCursor++ + return nil, true, ErrIPv4OptZeroLength + } + + if optLen == 1 { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + + if optLen > uint8(len(i.options)) { + i.ErrCursor++ + return nil, true, ErrIPv4OptionTruncated + } + + optionBody := i.options[:optLen] + i.nextErrCursor = i.ErrCursor + optLen + i.options = i.options[optLen:] + + // Check the length of some option types that we know. + switch optType { + case IPv4OptionTimestampType: + if optLen < IPv4OptionTimestampHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionTimestamp(optionBody) + return &retval, false, nil + + case IPv4OptionRecordRouteType: + if optLen < IPv4OptionRecordRouteHdrLength { + i.ErrCursor++ + return nil, true, ErrIPv4OptMalformed + } + retval := IPv4OptionRecordRoute(optionBody) + return &retval, false, nil + } + retval := IPv4OptionGeneric(optionBody) + return &retval, false, nil +} + +// +// IP Timestamp option - RFC 791 page 22. +// +--------+--------+--------+--------+ +// |01000100| length | pointer|oflw|flg| +// +--------+--------+--------+--------+ +// | internet address | +// +--------+--------+--------+--------+ +// | timestamp | +// +--------+--------+--------+--------+ +// | ... | +// +// Type = 68 +// +// The Option Length is the number of octets in the option counting +// the type, length, pointer, and overflow/flag octets (maximum +// length 40). +// +// The Pointer is the number of octets from the beginning of this +// option to the end of timestamps plus one (i.e., it points to the +// octet beginning the space for next timestamp). The smallest +// legal value is 5. The timestamp area is full when the pointer +// is greater than the length. +// +// The Overflow (oflw) [4 bits] is the number of IP modules that +// cannot register timestamps due to lack of space. +// +// The Flag (flg) [4 bits] values are +// +// 0 -- time stamps only, stored in consecutive 32-bit words, +// +// 1 -- each timestamp is preceded with internet address of the +// registering entity, +// +// 3 -- the internet address fields are prespecified. An IP +// module only registers its timestamp if it matches its own +// address with the next specified internet address. +// +// Timestamps are defined in RFC 791 page 22 as milliseconds since midnight UTC. +// +// The Timestamp is a right-justified, 32-bit timestamp in +// milliseconds since midnight UT. If the time is not available in +// milliseconds or cannot be provided with respect to midnight UT +// then any time may be inserted as a timestamp provided the high +// order bit of the timestamp field is set to one to indicate the +// use of a non-standard value. + +// IPv4OptTSFlags sefines the values expected in the Timestamp +// option Flags field. +type IPv4OptTSFlags uint8 + +// +// Timestamp option specific related constants. +const ( + // IPv4OptionTimestampHdrLength is the length of the timestamp option header. + IPv4OptionTimestampHdrLength = 4 + + // IPv4OptionTimestampSize is the size of an IP timestamp. + IPv4OptionTimestampSize = 4 + + // IPv4OptionTimestampWithAddrSize is the size of an IP timestamp + Address. + IPv4OptionTimestampWithAddrSize = IPv4AddressSize + IPv4OptionTimestampSize + + // IPv4OptionTimestampMaxSize is limited by space for options + IPv4OptionTimestampMaxSize = IPv4MaximumOptionsSize + + // IPv4OptionTimestampOnlyFlag is a flag indicating that only timestamp + // is present. + IPv4OptionTimestampOnlyFlag IPv4OptTSFlags = 0 + + // IPv4OptionTimestampWithIPFlag is a flag indicating that both timestamps and + // IP are present. + IPv4OptionTimestampWithIPFlag IPv4OptTSFlags = 1 + + // IPv4OptionTimestampWithPredefinedIPFlag is a flag indicating that + // predefined IP is present. + IPv4OptionTimestampWithPredefinedIPFlag IPv4OptTSFlags = 3 +) + +// ipv4TimestampTime provides the current time as specified in RFC 791. +func ipv4TimestampTime(clock tcpip.Clock) uint32 { + const millisecondsPerDay = 24 * 3600 * 1000 + const nanoPerMilli = 1000000 + return uint32((clock.NowNanoseconds() / nanoPerMilli) % millisecondsPerDay) +} + +// IP Timestamp option fields. +const ( + // IPv4OptTSPointerOffset is the offset of the Timestamp pointer field. + IPv4OptTSPointerOffset = 2 + + // IPv4OptTSPointerOffset is the offset of the combined Flag and Overflow + // fields, (each being 4 bits). + IPv4OptTSOFLWAndFLGOffset = 3 + // These constants define the sub byte fields of the Flag and OverFlow field. + ipv4OptionTimestampOverflowshift = 4 + ipv4OptionTimestampFlagsMask byte = 0x0f +) + +var _ IPv4Option = (*IPv4OptionTimestamp)(nil) + +// IPv4OptionTimestamp is a Timestamp option from RFC 791. +type IPv4OptionTimestamp []byte + +// Type implements IPv4Option.Type(). +func (ts *IPv4OptionTimestamp) Type() IPv4OptionType { return IPv4OptionTimestampType } + +// Size implements IPv4Option. +func (ts *IPv4OptionTimestamp) Size() uint8 { return uint8(len(*ts)) } + +// Contents implements IPv4Option. +func (ts *IPv4OptionTimestamp) Contents() []byte { return []byte(*ts) } + +// Pointer returns the pointer field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Pointer() uint8 { + return (*ts)[IPv4OptTSPointerOffset] +} + +// Flags returns the flags field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Flags() IPv4OptTSFlags { + return IPv4OptTSFlags((*ts)[IPv4OptTSOFLWAndFLGOffset] & ipv4OptionTimestampFlagsMask) +} + +// Overflow returns the Overflow field in the IP Timestamp option. +func (ts *IPv4OptionTimestamp) Overflow() uint8 { + return (*ts)[IPv4OptTSOFLWAndFLGOffset] >> ipv4OptionTimestampOverflowshift +} + +// IncOverflow increments the Overflow field in the IP Timestamp option. It +// returns the incremented value. If the return value is 0 then the field +// overflowed. +func (ts *IPv4OptionTimestamp) IncOverflow() uint8 { + (*ts)[IPv4OptTSOFLWAndFLGOffset] += 1 << ipv4OptionTimestampOverflowshift + return ts.Overflow() +} + +// UpdateTimestamp updates the fields of the next free timestamp slot. +func (ts *IPv4OptionTimestamp) UpdateTimestamp(addr tcpip.Address, clock tcpip.Clock) { + slot := (*ts)[ts.Pointer()-1:] + + switch ts.Flags() { + case IPv4OptionTimestampOnlyFlag: + binary.BigEndian.PutUint32(slot, ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampSize + case IPv4OptionTimestampWithIPFlag: + if n := copy(slot, addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + case IPv4OptionTimestampWithPredefinedIPFlag: + if tcpip.Address(slot[:IPv4AddressSize]) == addr { + binary.BigEndian.PutUint32(slot[IPv4AddressSize:], ipv4TimestampTime(clock)) + (*ts)[IPv4OptTSPointerOffset] += IPv4OptionTimestampWithAddrSize + } + } +} + +// RecordRoute option specific related constants. +// +// from RFC 791 page 20: +// Record Route +// +// +--------+--------+--------+---------//--------+ +// |00000111| length | pointer| route data | +// +--------+--------+--------+---------//--------+ +// Type=7 +// +// The record route option provides a means to record the route of +// an internet datagram. +// +// The option begins with the option type code. The second octet +// is the option length which includes the option type code and the +// length octet, the pointer octet, and length-3 octets of route +// data. The third octet is the pointer into the route data +// indicating the octet which begins the next area to store a route +// address. The pointer is relative to this option, and the +// smallest legal value for the pointer is 4. +const ( + // IPv4OptionRecordRouteHdrLength is the length of the Record Route option + // header. + IPv4OptionRecordRouteHdrLength = 3 + + // IPv4OptRRPointerOffset is the offset to the pointer field in an RR + // option, which points to the next free slot in the list of addresses. + IPv4OptRRPointerOffset = 2 +) + +var _ IPv4Option = (*IPv4OptionRecordRoute)(nil) + +// IPv4OptionRecordRoute is an IPv4 RecordRoute option defined by RFC 791. +type IPv4OptionRecordRoute []byte + +// Pointer returns the pointer field in the IP RecordRoute option. +func (rr *IPv4OptionRecordRoute) Pointer() uint8 { + return (*rr)[IPv4OptRRPointerOffset] +} + +// StoreAddress stores the given IPv4 address into the next free slot. +func (rr *IPv4OptionRecordRoute) StoreAddress(addr tcpip.Address) { + start := rr.Pointer() - 1 // A one based number. + // start and room checked by caller. + if n := copy((*rr)[start:], addr); n != IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, IPv4AddressSize)) + } + (*rr)[IPv4OptRRPointerOffset] += IPv4AddressSize +} + +// Type implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Type() IPv4OptionType { return IPv4OptionRecordRouteType } + +// Size implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Size() uint8 { return uint8(len(*rr)) } + +// Contents implements IPv4Option. +func (rr *IPv4OptionRecordRoute) Contents() []byte { return []byte(*rr) } diff --git a/pkg/tcpip/header/ipv6.go b/pkg/tcpip/header/ipv6.go index c5d8a3456..4e7e5f76a 100644 --- a/pkg/tcpip/header/ipv6.go +++ b/pkg/tcpip/header/ipv6.go @@ -101,8 +101,10 @@ const ( // The address is ff02::2. IPv6AllRoutersMulticastAddress tcpip.Address = "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" - // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 2460, - // section 5. + // IPv6MinimumMTU is the minimum MTU required by IPv6, per RFC 8200, + // section 5: + // IPv6 requires that every link in the Internet have an MTU of 1280 octets + // or greater. This is known as the IPv6 minimum link MTU. IPv6MinimumMTU = 1280 // IPv6Loopback is the IPv6 Loopback address. @@ -373,6 +375,12 @@ func IsV6LinkLocalAddress(addr tcpip.Address) bool { return addr[0] == 0xfe && (addr[1]&0xc0) == 0x80 } +// IsV6LoopbackAddress determines if the provided address is an IPv6 loopback +// address. +func IsV6LoopbackAddress(addr tcpip.Address) bool { + return addr == IPv6Loopback +} + // IsV6LinkLocalMulticastAddress determines if the provided address is an IPv6 // link-local multicast address. func IsV6LinkLocalMulticastAddress(addr tcpip.Address) bool { diff --git a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go index dc239a0d0..2777f1411 100644 --- a/pkg/tcpip/link/sharedmem/pipe/pipe_test.go +++ b/pkg/tcpip/link/sharedmem/pipe/pipe_test.go @@ -470,6 +470,7 @@ func TestConcurrentReaderWriter(t *testing.T) { const count = 1000000 var wg sync.WaitGroup + defer wg.Wait() wg.Add(1) go func() { defer wg.Done() @@ -489,30 +490,23 @@ func TestConcurrentReaderWriter(t *testing.T) { } }() - wg.Add(1) - go func() { - defer wg.Done() - runtime.Gosched() - for i := 0; i < count; i++ { - n := 1 + rr.Intn(80) - rb := rx.Pull() - for rb == nil { - rb = rx.Pull() - } + for i := 0; i < count; i++ { + n := 1 + rr.Intn(80) + rb := rx.Pull() + for rb == nil { + rb = rx.Pull() + } - if n != len(rb) { - t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) - } + if n != len(rb) { + t.Fatalf("Bad %v-th buffer length: got %v, want %v", i, len(rb), n) + } - for j := range rb { - if v := byte(rr.Intn(256)); v != rb[j] { - t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) - } + for j := range rb { + if v := byte(rr.Intn(256)); v != rb[j] { + t.Fatalf("Bad %v-th read buffer at index %v: got %v, want %v", i, j, rb[j], v) } - - rx.Flush() } - }() - wg.Wait() + rx.Flush() + } } diff --git a/pkg/tcpip/link/sniffer/sniffer.go b/pkg/tcpip/link/sniffer/sniffer.go index 560477926..b3e8c4b92 100644 --- a/pkg/tcpip/link/sniffer/sniffer.go +++ b/pkg/tcpip/link/sniffer/sniffer.go @@ -205,7 +205,12 @@ func logPacket(prefix string, protocol tcpip.NetworkProtocolNumber, pkt *stack.P // // We don't clone the original packet buffer so that the new packet buffer // does not have any of its headers set. - pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views())}) + // + // We trim the link headers from the cloned buffer as the sniffer doesn't + // handle link headers. + vv := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + vv.TrimFront(len(pkt.LinkHeader().View())) + pkt = stack.NewPacketBuffer(stack.PacketBufferOptions{Data: vv}) switch protocol { case header.IPv4ProtocolNumber: if ok := parse.IPv4(pkt); !ok { diff --git a/pkg/tcpip/link/tun/BUILD b/pkg/tcpip/link/tun/BUILD index 0243424f6..86f14db76 100644 --- a/pkg/tcpip/link/tun/BUILD +++ b/pkg/tcpip/link/tun/BUILD @@ -8,7 +8,7 @@ go_template_instance( out = "tun_endpoint_refs.go", package = "tun", prefix = "tunEndpoint", - template = "//pkg/refs_vfs2:refs_template", + template = "//pkg/refsvfs2:refs_template", types = { "T": "tunEndpoint", }, @@ -28,6 +28,7 @@ go_library( "//pkg/context", "//pkg/log", "//pkg/refs", + "//pkg/refsvfs2", "//pkg/sync", "//pkg/syserror", "//pkg/tcpip", diff --git a/pkg/tcpip/link/tun/device.go b/pkg/tcpip/link/tun/device.go index f94491026..cda6328a2 100644 --- a/pkg/tcpip/link/tun/device.go +++ b/pkg/tcpip/link/tun/device.go @@ -150,7 +150,6 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE // 2. Creating a new NIC. id := tcpip.NICID(s.UniqueID()) - // TODO(gvisor.dev/1486): enable leak check for tunEndpoint. endpoint := &tunEndpoint{ Endpoint: channel.New(defaultDevOutQueueLen, defaultDevMtu, ""), stack: s, @@ -158,6 +157,7 @@ func attachOrCreateNIC(s *stack.Stack, name, prefix string, linkCaps stack.LinkE name: name, isTap: prefix == "tap", } + endpoint.EnableLeakCheck() endpoint.Endpoint.LinkEPCapabilities = linkCaps if endpoint.name == "" { endpoint.name = fmt.Sprintf("%s%d", prefix, id) diff --git a/pkg/tcpip/network/arp/BUILD b/pkg/tcpip/network/arp/BUILD index b40dde96b..8a6bcfc2c 100644 --- a/pkg/tcpip/network/arp/BUILD +++ b/pkg/tcpip/network/arp/BUILD @@ -30,5 +30,6 @@ go_test( "//pkg/tcpip/stack", "//pkg/tcpip/transport/icmp", "@com_github_google_go_cmp//cmp:go_default_library", + "@com_github_google_go_cmp//cmp/cmpopts:go_default_library", ], ) diff --git a/pkg/tcpip/network/arp/arp.go b/pkg/tcpip/network/arp/arp.go index 7df77c66e..33a4a0720 100644 --- a/pkg/tcpip/network/arp/arp.go +++ b/pkg/tcpip/network/arp/arp.go @@ -18,6 +18,7 @@ package arp import ( + "fmt" "sync/atomic" "gvisor.dev/gvisor/pkg/tcpip" @@ -121,7 +122,7 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu return tcpip.ErrNotSupported } -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } @@ -144,34 +145,43 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { linkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) e.linkAddrCache.AddLinkAddress(e.nic.ID(), addr, linkAddr) } else { - if r.Stack().CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), header.IPv4ProtocolNumber, localAddr) == 0 { return // we have no useful answer, ignore the request } remoteAddr := tcpip.Address(h.ProtocolAddressSender()) remoteLinkAddr := tcpip.LinkAddress(h.HardwareAddressSender()) - e.nud.HandleProbe(remoteAddr, localAddr, ProtocolNumber, remoteLinkAddr, e.protocol) + e.nud.HandleProbe(remoteAddr, ProtocolNumber, remoteLinkAddr, e.protocol) } - // As per RFC 826, under Packet Reception: - // Swap hardware and protocol fields, putting the local hardware and - // protocol addresses in the sender fields. - // - // Send the packet to the (new) target hardware address on the same - // hardware on which the request was received. - origSender := h.HardwareAddressSender() - r.RemoteLinkAddress = tcpip.LinkAddress(origSender) respPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ ReserveHeaderBytes: int(e.nic.MaxHeaderLength()) + header.ARPSize, }) packet := header.ARP(respPkt.NetworkHeader().Push(header.ARPSize)) + respPkt.NetworkProtocolNumber = ProtocolNumber packet.SetIPv4OverEthernet() packet.SetOp(header.ARPReply) - copy(packet.HardwareAddressSender(), r.LocalLinkAddress[:]) - copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()) - copy(packet.HardwareAddressTarget(), origSender) - copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()) - _ = e.nic.WritePacket(r, nil /* gso */, ProtocolNumber, respPkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(packet.HardwareAddressSender(), e.nic.LinkAddress()) + if n := copy(packet.ProtocolAddressSender(), h.ProtocolAddressTarget()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + origSender := h.HardwareAddressSender() + if n := copy(packet.HardwareAddressTarget(), origSender); n != header.EthernetAddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.EthernetAddressSize)) + } + if n := copy(packet.ProtocolAddressTarget(), h.ProtocolAddressSender()); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + + // As per RFC 826, under Packet Reception: + // Swap hardware and protocol fields, putting the local hardware and + // protocol addresses in the sender fields. + // + // Send the packet to the (new) target hardware address on the same + // hardware on which the request was received. + _ = e.nic.WritePacketToRemote(tcpip.LinkAddress(origSender), nil /* gso */, ProtocolNumber, respPkt) case header.ARPReply: addr := tcpip.Address(h.ProtocolAddressSender()) @@ -199,6 +209,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // protocol implements stack.NetworkProtocol and stack.LinkAddressResolver. type protocol struct { + stack *stack.Stack } func (p *protocol) Number() tcpip.NetworkProtocolNumber { return ProtocolNumber } @@ -227,26 +238,44 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver.LinkAddressRequest. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - r := &stack.Route{ - NetProto: ProtocolNumber, - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + if len(remoteLinkAddr) == 0 { + remoteLinkAddr = header.EthernetBroadcastAddress } - if len(r.RemoteLinkAddress) == 0 { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + + nicID := nic.ID() + if len(localAddr) == 0 { + addr, err := p.stack.GetMainNICAddress(nicID, header.IPv4ProtocolNumber) + if err != nil { + return err + } + + if len(addr.Address) == 0 { + return tcpip.ErrNetworkUnreachable + } + + localAddr = addr.Address + } else if p.stack.CheckLocalAddress(nicID, header.IPv4ProtocolNumber, localAddr) == 0 { + return tcpip.ErrBadLocalAddress } pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.ARPSize, + ReserveHeaderBytes: int(nic.MaxHeaderLength()) + header.ARPSize, }) h := header.ARP(pkt.NetworkHeader().Push(header.ARPSize)) + pkt.NetworkProtocolNumber = ProtocolNumber h.SetIPv4OverEthernet() h.SetOp(header.ARPRequest) - copy(h.HardwareAddressSender(), linkEP.LinkAddress()) - copy(h.ProtocolAddressSender(), localAddr) - copy(h.ProtocolAddressTarget(), addr) - - return linkEP.WritePacket(r, nil /* gso */, ProtocolNumber, pkt) + // TODO(gvisor.dev/issue/4582): check copied length once TAP devices have a + // link address. + _ = copy(h.HardwareAddressSender(), nic.LinkAddress()) + if n := copy(h.ProtocolAddressSender(), localAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + if n := copy(h.ProtocolAddressTarget(), targetAddr); n != header.IPv4AddressSize { + panic(fmt.Sprintf("copied %d bytes, expected %d bytes", n, header.IPv4AddressSize)) + } + return nic.WritePacketToRemote(remoteLinkAddr, nil /* gso */, ProtocolNumber, pkt) } // ResolveStaticAddress implements stack.LinkAddressResolver.ResolveStaticAddress. @@ -286,6 +315,6 @@ func (*protocol) Parse(pkt *stack.PacketBuffer) (proto tcpip.TransportProtocolNu // Note, to make sure that the ARP endpoint receives ARP packets, the "arp" // address must be added to every NIC that should respond to ARP requests. See // ProtocolAddress for more details. -func NewProtocol(*stack.Stack) stack.NetworkProtocol { - return &protocol{} +func NewProtocol(s *stack.Stack) stack.NetworkProtocol { + return &protocol{stack: s} } diff --git a/pkg/tcpip/network/arp/arp_test.go b/pkg/tcpip/network/arp/arp_test.go index 626af975a..087ee9c66 100644 --- a/pkg/tcpip/network/arp/arp_test.go +++ b/pkg/tcpip/network/arp/arp_test.go @@ -22,6 +22,7 @@ import ( "time" "github.com/google/go-cmp/cmp" + "github.com/google/go-cmp/cmp/cmpopts" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -78,13 +79,11 @@ func (t eventType) String() string { type eventInfo struct { eventType eventType nicID tcpip.NICID - addr tcpip.Address - linkAddr tcpip.LinkAddress - state stack.NeighborState + entry stack.NeighborEntry } func (e eventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.eventType, e.nicID, e.addr, e.linkAddr, e.state) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.eventType, e.nicID, e.entry) } // arpDispatcher implements NUDDispatcher to validate the dispatching of @@ -96,35 +95,29 @@ type arpDispatcher struct { var _ stack.NUDDispatcher = (*arpDispatcher)(nil) -func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryChanged, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } -func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state stack.NeighborState, updatedAt time.Time) { +func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry stack.NeighborEntry) { e := eventInfo{ eventType: entryRemoved, nicID: nicID, - addr: addr, - linkAddr: linkAddr, - state: state, + entry: entry, } d.C <- e } @@ -132,7 +125,7 @@ func (d *arpDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, func (d *arpDispatcher) waitForEvent(ctx context.Context, want eventInfo) error { select { case got := <-d.C: - if diff := cmp.Diff(got, want, cmp.AllowUnexported(got)); diff != "" { + if diff := cmp.Diff(got, want, cmp.AllowUnexported(got), cmpopts.IgnoreFields(stack.NeighborEntry{}, "UpdatedAtNanos")); diff != "" { return fmt.Errorf("got invalid event (-got +want):\n%s", diff) } case <-ctx.Done(): @@ -373,9 +366,11 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { wantEvent := eventInfo{ eventType: entryAdded, nicID: nicID, - addr: test.senderAddr, - linkAddr: tcpip.LinkAddress(test.senderLinkAddr), - state: stack.Stale, + entry: stack.NeighborEntry{ + Addr: test.senderAddr, + LinkAddr: tcpip.LinkAddress(test.senderLinkAddr), + State: stack.Stale, + }, } if err := c.nudDisp.waitForEventWithTimeout(wantEvent, time.Second); err != nil { t.Fatal(err) @@ -404,9 +399,6 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { if got, want := neigh.LinkAddr, test.senderLinkAddr; got != want { t.Errorf("got neighbor LinkAddr = %s, want = %s", got, want) } - if got, want := neigh.LocalAddr, stackAddr; got != want { - t.Errorf("got neighbor LocalAddr = %s, want = %s", got, want) - } if got, want := neigh.State, stack.Stale; got != want { t.Errorf("got neighbor State = %s, want = %s", got, want) } @@ -423,43 +415,164 @@ func TestDirectRequestWithNeighborCache(t *testing.T) { } } +var _ stack.NetworkInterface = (*testInterface)(nil) + +type testInterface struct { + stack.LinkEndpoint + + nicID tcpip.NICID +} + +func (t *testInterface) ID() tcpip.NICID { + return t.nicID +} + +func (*testInterface) IsLoopback() bool { + return false +} + +func (*testInterface) Name() string { + return "" +} + +func (*testInterface) Enabled() bool { + return true +} + +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + + testAddr := tcpip.Address([]byte{1, 2, 3, 4}) + tests := []struct { name string + nicAddr tcpip.Address + localAddr tcpip.Address remoteLinkAddr tcpip.LinkAddress - expectLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedLocalAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", + name: "Unicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast", + nicAddr: stackAddr, + localAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: remoteLinkAddr, + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: remoteLinkAddr, + }, + { + name: "Multicast with unspecified source", + nicAddr: stackAddr, + remoteLinkAddr: "", + expectedLocalAddr: stackAddr, + expectedRemoteLinkAddr: header.EthernetBroadcastAddress, + }, + { + name: "Unicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Multicast with unassigned address", + localAddr: testAddr, + remoteLinkAddr: "", + expectedErr: tcpip.ErrBadLocalAddress, + }, + { + name: "Unicast with no local address available", remoteLinkAddr: remoteLinkAddr, - expectLinkAddr: remoteLinkAddr, + expectedErr: tcpip.ErrNetworkUnreachable, }, { - name: "Multicast", + name: "Multicast with no local address available", remoteLinkAddr: "", - expectLinkAddr: header.EthernetBroadcastAddress, + expectedErr: tcpip.ErrNetworkUnreachable, }, } for _, test := range tests { - p := arp.NewProtocol(nil) - linkRes, ok := p.(stack.LinkAddressResolver) - if !ok { - t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") - } + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, + }) + p := s.NetworkProtocolInstance(arp.ProtocolNumber) + linkRes, ok := p.(stack.LinkAddressResolver) + if !ok { + t.Fatal("expected ARP protocol to implement stack.LinkAddressResolver") + } - linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) - if err := linkRes.LinkAddressRequest(stackAddr, remoteAddr, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", stackAddr, remoteAddr, test.remoteLinkAddr, err) - } + linkEP := channel.New(defaultChannelSize, defaultMTU, stackLinkAddr) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } - pkt, ok := linkEP.Read() - if !ok { - t.Fatal("expected to send a link address request") - } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ipv4.ProtocolNumber, test.nicAddr, err) + } + } - if got, want := pkt.Route.RemoteLinkAddress, test.expectLinkAddr; got != want { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", got, want) - } + // We pass a test network interface to LinkAddressRequest with the same + // NIC ID and link endpoint used by the NIC we created earlier so that we + // can mock a link address request and observe the packets sent to the + // link endpoint even though the stack uses the real NIC to validate the + // local address. + if err := linkRes.LinkAddressRequest(remoteAddr, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Fatalf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", remoteAddr, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return + } + + pkt, ok := linkEP.Read() + if !ok { + t.Fatal("expected to send a link address request") + } + + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) + } + + rep := header.ARP(stack.PayloadSince(pkt.Pkt.NetworkHeader())) + if got := tcpip.LinkAddress(rep.HardwareAddressSender()); got != stackLinkAddr { + t.Errorf("got HardwareAddressSender = %s, want = %s", got, stackLinkAddr) + } + if got := tcpip.Address(rep.ProtocolAddressSender()); got != test.expectedLocalAddr { + t.Errorf("got ProtocolAddressSender = %s, want = %s", got, test.expectedLocalAddr) + } + if got, want := tcpip.LinkAddress(rep.HardwareAddressTarget()), tcpip.LinkAddress("\x00\x00\x00\x00\x00\x00"); got != want { + t.Errorf("got HardwareAddressTarget = %s, want = %s", got, want) + } + if got := tcpip.Address(rep.ProtocolAddressTarget()); got != remoteAddr { + t.Errorf("got ProtocolAddressTarget = %s, want = %s", got, remoteAddr) + } + }) } } diff --git a/pkg/tcpip/network/fragmentation/fragmentation.go b/pkg/tcpip/network/fragmentation/fragmentation.go index ed502a473..936601287 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation.go +++ b/pkg/tcpip/network/fragmentation/fragmentation.go @@ -136,8 +136,16 @@ func NewFragmentation(blockSize uint16, highMemoryLimit, lowMemoryLimit int, rea // proto is the protocol number marked in the fragment being processed. It has // to be given here outside of the FragmentID struct because IPv6 should not use // the protocol to identify a fragment. +// +// releaseCB is a callback that will run when the fragment reassembly of a +// packet is complete or cancelled. releaseCB take a a boolean argument which is +// true iff the reassembly is cancelled due to timeout. releaseCB should be +// passed only with the first fragment of a packet. If more than one releaseCB +// are passed for the same packet, only the first releaseCB will be saved for +// the packet and the succeeding ones will be dropped by running them +// immediately with a false argument. func (f *Fragmentation) Process( - id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView) ( + id FragmentID, first, last uint16, more bool, proto uint8, vv buffer.VectorisedView, releaseCB func(bool)) ( buffer.VectorisedView, uint8, bool, error) { if first > last { return buffer.VectorisedView{}, 0, false, fmt.Errorf("first=%d is greater than last=%d: %w", first, last, ErrInvalidArgs) @@ -171,6 +179,12 @@ func (f *Fragmentation) Process( f.releaseReassemblersLocked() } } + if releaseCB != nil { + if !r.setCallback(releaseCB) { + // We got a duplicate callback. Release it immediately. + releaseCB(false /* timedOut */) + } + } f.mu.Unlock() res, firstFragmentProto, done, consumed, err := r.process(first, last, more, proto, vv) @@ -178,14 +192,14 @@ func (f *Fragmentation) Process( // We probably got an invalid sequence of fragments. Just // discard the reassembler and move on. f.mu.Lock() - f.release(r) + f.release(r, false /* timedOut */) f.mu.Unlock() return buffer.VectorisedView{}, 0, false, fmt.Errorf("fragmentation processing error: %w", err) } f.mu.Lock() f.size += consumed if done { - f.release(r) + f.release(r, false /* timedOut */) } // Evict reassemblers if we are consuming more memory than highLimit until // we reach lowLimit. @@ -195,14 +209,14 @@ func (f *Fragmentation) Process( if tail == nil { break } - f.release(tail) + f.release(tail, false /* timedOut */) } } f.mu.Unlock() return res, firstFragmentProto, done, nil } -func (f *Fragmentation) release(r *reassembler) { +func (f *Fragmentation) release(r *reassembler, timedOut bool) { // Before releasing a fragment we need to check if r is already marked as done. // Otherwise, we would delete it twice. if r.checkDoneOrMark() { @@ -216,6 +230,8 @@ func (f *Fragmentation) release(r *reassembler) { log.Printf("memory counter < 0 (%d), this is an accounting bug that requires investigation", f.size) f.size = 0 } + + r.release(timedOut) // releaseCB may run. } // releaseReassemblersLocked releases already-expired reassemblers, then @@ -238,31 +254,31 @@ func (f *Fragmentation) releaseReassemblersLocked() { break } // If the oldest reassembler has already expired, release it. - f.release(r) + f.release(r, true /* timedOut*/) } } // PacketFragmenter is the book-keeping struct for packet fragmentation. type PacketFragmenter struct { - transportHeader buffer.View - data buffer.VectorisedView - reserve int - innerMTU int - fragmentCount int - currentFragment int - fragmentOffset int + transportHeader buffer.View + data buffer.VectorisedView + reserve int + fragmentPayloadLen int + fragmentCount int + currentFragment int + fragmentOffset int } // MakePacketFragmenter prepares the struct needed for packet fragmentation. // // pkt is the packet to be fragmented. // -// innerMTU is the maximum number of bytes of fragmentable data a fragment can +// fragmentPayloadLen is the maximum number of bytes of fragmentable data a fragment can // have. // // reserve is the number of bytes that should be reserved for the headers in // each generated fragment. -func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) PacketFragmenter { +func MakePacketFragmenter(pkt *stack.PacketBuffer, fragmentPayloadLen uint32, reserve int) PacketFragmenter { // As per RFC 8200 Section 4.5, some IPv6 extension headers should not be // repeated in each fragment. However we do not currently support any header // of that kind yet, so the following computation is valid for both IPv4 and @@ -273,13 +289,13 @@ func MakePacketFragmenter(pkt *stack.PacketBuffer, innerMTU int, reserve int) Pa var fragmentableData buffer.VectorisedView fragmentableData.AppendView(pkt.TransportHeader().View()) fragmentableData.Append(pkt.Data) - fragmentCount := (fragmentableData.Size() + innerMTU - 1) / innerMTU + fragmentCount := (uint32(fragmentableData.Size()) + fragmentPayloadLen - 1) / fragmentPayloadLen return PacketFragmenter{ - data: fragmentableData, - reserve: reserve, - innerMTU: innerMTU, - fragmentCount: fragmentCount, + data: fragmentableData, + reserve: reserve, + fragmentPayloadLen: int(fragmentPayloadLen), + fragmentCount: int(fragmentCount), } } @@ -302,7 +318,7 @@ func (pf *PacketFragmenter) BuildNextFragment() (*stack.PacketBuffer, int, int, }) // Copy data for the fragment. - copied := pf.data.ReadToVV(&fragPkt.Data, pf.innerMTU) + copied := pf.data.ReadToVV(&fragPkt.Data, pf.fragmentPayloadLen) offset := pf.fragmentOffset pf.fragmentOffset += copied diff --git a/pkg/tcpip/network/fragmentation/fragmentation_test.go b/pkg/tcpip/network/fragmentation/fragmentation_test.go index d3c7d7f92..5dcd10730 100644 --- a/pkg/tcpip/network/fragmentation/fragmentation_test.go +++ b/pkg/tcpip/network/fragmentation/fragmentation_test.go @@ -105,7 +105,7 @@ func TestFragmentationProcess(t *testing.T) { f := NewFragmentation(minBlockSize, 1024, 512, reassembleTimeout, &faketime.NullClock{}) firstFragmentProto := c.in[0].proto for i, in := range c.in { - vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv) + vv, proto, done, err := f.Process(in.id, in.first, in.last, in.more, in.proto, in.vv, nil) if err != nil { t.Fatalf("f.Process(%+v, %d, %d, %t, %d, %X) failed: %s", in.id, in.first, in.last, in.more, in.proto, in.vv.ToView(), err) @@ -240,7 +240,7 @@ func TestReassemblingTimeout(t *testing.T) { for _, event := range test.events { clock.Advance(event.clockAdvance) if frag := event.fragment; frag != nil { - _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data)) + _, _, done, err := f.Process(FragmentID{}, frag.first, frag.last, frag.more, protocol, vv(len(frag.data), frag.data), nil) if err != nil { t.Fatalf("%s: f.Process failed: %s", event.name, err) } @@ -259,15 +259,15 @@ func TestReassemblingTimeout(t *testing.T) { func TestMemoryLimits(t *testing.T) { f := NewFragmentation(minBlockSize, 3, 1, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{ID: 0}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send first fragment with id = 1. - f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1")) + f.Process(FragmentID{ID: 1}, 0, 0, true, 0xFF, vv(1, "1"), nil) // Send first fragment with id = 2. - f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2")) + f.Process(FragmentID{ID: 2}, 0, 0, true, 0xFF, vv(1, "2"), nil) // Send first fragment with id = 3. This should caused id = 0 and id = 1 to be // evicted. - f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3")) + f.Process(FragmentID{ID: 3}, 0, 0, true, 0xFF, vv(1, "3"), nil) if _, ok := f.reassemblers[FragmentID{ID: 0}]; ok { t.Errorf("Memory limits are not respected: id=0 has not been evicted.") @@ -283,9 +283,9 @@ func TestMemoryLimits(t *testing.T) { func TestMemoryLimitsIgnoresDuplicates(t *testing.T) { f := NewFragmentation(minBlockSize, 1, 0, reassembleTimeout, &faketime.NullClock{}) // Send first fragment with id = 0. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) // Send the same packet again. - f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0")) + f.Process(FragmentID{}, 0, 0, true, 0xFF, vv(1, "0"), nil) got := f.size want := 1 @@ -377,7 +377,7 @@ func TestErrors(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { f := NewFragmentation(test.blockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) - _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data)) + _, _, done, err := f.Process(FragmentID{}, test.first, test.last, test.more, 0, vv(len(test.data), test.data), nil) if !errors.Is(err, test.err) { t.Errorf("got Process(_, %d, %d, %t, _, %q) = (_, _, _, %v), want = (_, _, _, %v)", test.first, test.last, test.more, test.data, err, test.err) } @@ -403,14 +403,14 @@ func TestPacketFragmenter(t *testing.T) { tests := []struct { name string - innerMTU int + fragmentPayloadLen uint32 transportHeaderLen int payloadSize int wantFragments []fragmentInfo }{ { name: "Packet exactly fits in MTU", - innerMTU: 1280, + fragmentPayloadLen: 1280, transportHeaderLen: 0, payloadSize: 1280, wantFragments: []fragmentInfo{ @@ -419,7 +419,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet exactly does not fit in MTU", - innerMTU: 1000, + fragmentPayloadLen: 1000, transportHeaderLen: 0, payloadSize: 1001, wantFragments: []fragmentInfo{ @@ -429,7 +429,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a transport header", - innerMTU: 560, + fragmentPayloadLen: 560, transportHeaderLen: 40, payloadSize: 560, wantFragments: []fragmentInfo{ @@ -439,7 +439,7 @@ func TestPacketFragmenter(t *testing.T) { }, { name: "Packet has a huge transport header", - innerMTU: 500, + fragmentPayloadLen: 500, transportHeaderLen: 1300, payloadSize: 500, wantFragments: []fragmentInfo{ @@ -458,7 +458,7 @@ func TestPacketFragmenter(t *testing.T) { originalPayload.AppendView(pkt.TransportHeader().View()) originalPayload.Append(pkt.Data) var reassembledPayload buffer.VectorisedView - pf := MakePacketFragmenter(pkt, test.innerMTU, reserve) + pf := MakePacketFragmenter(pkt, test.fragmentPayloadLen, reserve) for i := 0; ; i++ { fragPkt, offset, copied, more := pf.BuildNextFragment() wantFragment := test.wantFragments[i] @@ -474,8 +474,8 @@ func TestPacketFragmenter(t *testing.T) { if more != wantFragment.more { t.Errorf("(fragment #%d) got more = %t, want = %t", i, more, wantFragment.more) } - if got := fragPkt.Size(); got > test.innerMTU { - t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.innerMTU) + if got := uint32(fragPkt.Size()); got > test.fragmentPayloadLen { + t.Errorf("(fragment #%d) got fragPkt.Size() = %d, want <= %d", i, got, test.fragmentPayloadLen) } if got := fragPkt.AvailableHeaderBytes(); got != reserve { t.Errorf("(fragment #%d) got fragPkt.AvailableHeaderBytes() = %d, want = %d", i, got, reserve) @@ -497,3 +497,89 @@ func TestPacketFragmenter(t *testing.T) { }) } } + +func TestReleaseCallback(t *testing.T) { + const ( + proto = 99 + ) + + var result int + var callbackReasonIsTimeout bool + cb1 := func(timedOut bool) { result = 1; callbackReasonIsTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; callbackReasonIsTimeout = timedOut } + + tests := []struct { + name string + callbacks []func(bool) + timeout bool + wantResult int + wantCallbackReasonIsTimeout bool + }{ + { + name: "callback runs on release", + callbacks: []func(bool){cb1}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "first callback is nil", + callbacks: []func(bool){nil, cb2}, + timeout: false, + wantResult: 2, + wantCallbackReasonIsTimeout: false, + }, + { + name: "two callbacks - first one is set", + callbacks: []func(bool){cb1, cb2}, + timeout: false, + wantResult: 1, + wantCallbackReasonIsTimeout: false, + }, + { + name: "callback runs on timeout", + callbacks: []func(bool){cb1}, + timeout: true, + wantResult: 1, + wantCallbackReasonIsTimeout: true, + }, + { + name: "no callbacks", + callbacks: []func(bool){nil}, + timeout: false, + wantResult: 0, + wantCallbackReasonIsTimeout: false, + }, + } + + id := FragmentID{ID: 0} + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + result = 0 + callbackReasonIsTimeout = false + + f := NewFragmentation(minBlockSize, HighFragThreshold, LowFragThreshold, reassembleTimeout, &faketime.NullClock{}) + + for i, cb := range test.callbacks { + _, _, _, err := f.Process(id, uint16(i), uint16(i), true, proto, vv(1, "0"), cb) + if err != nil { + t.Errorf("f.Process error = %s", err) + } + } + + r, ok := f.reassemblers[id] + if !ok { + t.Fatalf("Reassemberr not found") + } + f.release(r, test.timeout) + + if result != test.wantResult { + t.Errorf("got result = %d, want = %d", result, test.wantResult) + } + if callbackReasonIsTimeout != test.wantCallbackReasonIsTimeout { + t.Errorf("got callbackReasonIsTimeout = %t, want = %t", callbackReasonIsTimeout, test.wantCallbackReasonIsTimeout) + } + }) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler.go b/pkg/tcpip/network/fragmentation/reassembler.go index 9bb051a30..c0cc0bde0 100644 --- a/pkg/tcpip/network/fragmentation/reassembler.go +++ b/pkg/tcpip/network/fragmentation/reassembler.go @@ -41,6 +41,7 @@ type reassembler struct { heap fragHeap done bool creationTime int64 + callback func(bool) } func newReassembler(id FragmentID, clock tcpip.Clock) *reassembler { @@ -123,3 +124,24 @@ func (r *reassembler) checkDoneOrMark() bool { r.mu.Unlock() return prev } + +func (r *reassembler) setCallback(c func(bool)) bool { + r.mu.Lock() + defer r.mu.Unlock() + if r.callback != nil { + return false + } + r.callback = c + return true +} + +func (r *reassembler) release(timedOut bool) { + r.mu.Lock() + callback := r.callback + r.callback = nil + r.mu.Unlock() + + if callback != nil { + callback(timedOut) + } +} diff --git a/pkg/tcpip/network/fragmentation/reassembler_test.go b/pkg/tcpip/network/fragmentation/reassembler_test.go index a0a04a027..fa2a70dc8 100644 --- a/pkg/tcpip/network/fragmentation/reassembler_test.go +++ b/pkg/tcpip/network/fragmentation/reassembler_test.go @@ -105,3 +105,26 @@ func TestUpdateHoles(t *testing.T) { } } } + +func TestSetCallback(t *testing.T) { + result := 0 + reasonTimeout := false + + cb1 := func(timedOut bool) { result = 1; reasonTimeout = timedOut } + cb2 := func(timedOut bool) { result = 2; reasonTimeout = timedOut } + + r := newReassembler(FragmentID{}, &faketime.NullClock{}) + if !r.setCallback(cb1) { + t.Errorf("setCallback failed") + } + if r.setCallback(cb2) { + t.Errorf("setCallback should fail if one is already set") + } + r.release(true) + if result != 1 { + t.Errorf("got result = %d, want = 1", result) + } + if !reasonTimeout { + t.Errorf("got reasonTimeout = %t, want = true", reasonTimeout) + } +} diff --git a/pkg/tcpip/network/ip_test.go b/pkg/tcpip/network/ip_test.go index f20b94d97..8873bd91f 100644 --- a/pkg/tcpip/network/ip_test.go +++ b/pkg/tcpip/network/ip_test.go @@ -110,8 +110,9 @@ func (t *testObject) checkValues(protocol tcpip.TransportProtocolNumber, vv buff // DeliverTransportPacket is called by network endpoints after parsing incoming // packets. This is used by the test object to verify that the results of the // parsing are expected. -func (t *testObject) DeliverTransportPacket(r *stack.Route, protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { - t.checkValues(protocol, pkt.Data, r.RemoteAddress, r.LocalAddress) +func (t *testObject) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *stack.PacketBuffer) stack.TransportPacketDisposition { + netHdr := pkt.Network() + t.checkValues(protocol, pkt.Data, netHdr.SourceAddress(), netHdr.DestinationAddress()) t.dataCalls++ return stack.TransportPacketHandled } @@ -304,6 +305,10 @@ func (t *testInterface) setEnabled(v bool) { t.mu.disabled = !v } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestSourceAddressValidation(t *testing.T) { rxIPv4ICMP := func(e *channel.Endpoint, src tcpip.Address) { totalLen := header.IPv4MinimumSize + header.ICMPv4MinimumSize @@ -604,7 +609,8 @@ func TestIPv4Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -690,6 +696,10 @@ func TestIPv4ReceiveControl(t *testing.T) { view[i] = uint8(i) } + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) + // Give packet to IPv4 endpoint, dispatcher will validate that // it's ok. nic.testObject.protocol = 10 @@ -699,7 +709,9 @@ func TestIPv4ReceiveControl(t *testing.T) { nic.testObject.typ = c.expectedTyp nic.testObject.extra = c.expectedExtra - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv4MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv4MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -780,7 +792,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 0 { t.Fatalf("Bad number of data calls: got %x, want 0", nic.testObject.dataCalls) } @@ -792,7 +805,8 @@ func TestIPv4FragmentationReceive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -892,7 +906,8 @@ func TestIPv6Receive(t *testing.T) { if _, _, ok := proto.Parse(pkt); !ok { t.Fatalf("failed to parse packet: %x", pkt.Data.ToView()) } - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if nic.testObject.dataCalls != 1 { t.Fatalf("Bad number of data calls: got %x, want 1", nic.testObject.dataCalls) } @@ -1009,7 +1024,9 @@ func TestIPv6ReceiveControl(t *testing.T) { // Set ICMPv6 checksum. icmp.SetChecksum(header.ICMPv6Checksum(icmp, outerSrcAddr, localIPv6Addr, buffer.VectorisedView{})) - ep.HandlePacket(&r, truncatedPacket(view, c.trunc, header.IPv6MinimumSize)) + pkt := truncatedPacket(view, c.trunc, header.IPv6MinimumSize) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) if want := c.expectedCount; nic.testObject.controlCalls != want { t.Fatalf("Bad number of control calls for %q case: got %v, want %v", c.name, nic.testObject.controlCalls, want) } @@ -1063,7 +1080,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum tcpip.NetworkProtocolNumber nicAddr tcpip.Address remoteAddr tcpip.Address - pktGen func(*testing.T, tcpip.Address) buffer.View + pktGen func(*testing.T, tcpip.Address) buffer.VectorisedView checker func(*testing.T, *stack.PacketBuffer, tcpip.Address) expectedErr *tcpip.Error }{ @@ -1073,7 +1090,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1087,7 +1104,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1115,7 +1132,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv4MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1129,7 +1146,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1139,7 +1156,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1148,7 +1165,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1158,7 +1175,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) ip.Encode(&header.IPv4Fields{ IHL: header.IPv4MinimumSize, @@ -1167,7 +1184,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1195,7 +1212,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv4.ProtocolNumber, nicAddr: localIPv4Addr, remoteAddr: remoteIPv4Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ipHdrLen := header.IPv4MinimumSize + len(ipv4Options) totalLen := ipHdrLen + len(data) hdr := buffer.NewPrependable(totalLen) @@ -1213,7 +1230,49 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { if n := copy(ip.Options(), ipv4Options); n != len(ipv4Options) { t.Fatalf("copied %d bytes, expected %d bytes", n, len(ipv4Options)) } - return hdr.View() + return hdr.View().ToVectorisedView() + }, + checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { + if src == header.IPv4Any { + src = localIPv4Addr + } + + netHdr := pkt.NetworkHeader() + + hdrLen := header.IPv4MinimumSize + len(ipv4Options) + if len(netHdr.View()) != hdrLen { + t.Errorf("got len(netHdr.View()) = %d, want = %d", len(netHdr.View()), hdrLen) + } + + checker.IPv4(t, stack.PayloadSince(netHdr), + checker.SrcAddr(src), + checker.DstAddr(remoteIPv4Addr), + checker.IPv4HeaderLength(hdrLen), + checker.IPFullLength(uint16(hdrLen+len(data))), + checker.IPv4Options(ipv4Options), + checker.IPPayload(data), + ) + }, + }, + { + name: "IPv4 with options and data across views", + protoFactory: ipv4.NewProtocol, + protoNum: ipv4.ProtocolNumber, + nicAddr: localIPv4Addr, + remoteAddr: remoteIPv4Addr, + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { + ip := header.IPv4(make([]byte, header.IPv4MinimumSize)) + ip.Encode(&header.IPv4Fields{ + IHL: uint8(header.IPv4MinimumSize + len(ipv4Options)), + Protocol: transportProto, + TTL: ipv4.DefaultTTL, + SrcAddr: src, + DstAddr: header.IPv4Any, + }) + vv := buffer.View(ip).ToVectorisedView() + vv.AppendView(ipv4Options) + vv.AppendView(data) + return vv }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv4Any { @@ -1243,7 +1302,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1256,7 +1315,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1283,7 +1342,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { totalLen := header.IPv6MinimumSize + len(ipv6FragmentExtHdr) + len(data) hdr := buffer.NewPrependable(totalLen) if n := copy(hdr.Prepend(len(data)), data); n != len(data) { @@ -1299,7 +1358,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return hdr.View() + return hdr.View().ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1326,7 +1385,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1334,7 +1393,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip) + return buffer.View(ip).ToVectorisedView() }, checker: func(t *testing.T, pkt *stack.PacketBuffer, src tcpip.Address) { if src == header.IPv6Any { @@ -1361,7 +1420,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { protoNum: ipv6.ProtocolNumber, nicAddr: localIPv6Addr, remoteAddr: remoteIPv6Addr, - pktGen: func(t *testing.T, src tcpip.Address) buffer.View { + pktGen: func(t *testing.T, src tcpip.Address) buffer.VectorisedView { ip := header.IPv6(make([]byte, header.IPv6MinimumSize)) ip.Encode(&header.IPv6Fields{ NextHeader: transportProto, @@ -1369,7 +1428,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { SrcAddr: src, DstAddr: header.IPv4Any, }) - return buffer.View(ip[:len(ip)-1]) + return buffer.View(ip[:len(ip)-1]).ToVectorisedView() }, expectedErr: tcpip.ErrMalformedHeader, }, @@ -1413,7 +1472,7 @@ func TestWriteHeaderIncludedPacket(t *testing.T) { defer r.Release() if err := r.WriteHeaderIncludedPacket(stack.NewPacketBuffer(stack.PacketBufferOptions{ - Data: test.pktGen(t, subTest.srcAddr).ToVectorisedView(), + Data: test.pktGen(t, subTest.srcAddr), })); err != test.expectedErr { t.Fatalf("got r.WriteHeaderIncludedPacket(_) = %s, want = %s", err, test.expectedErr) } diff --git a/pkg/tcpip/network/ipv4/BUILD b/pkg/tcpip/network/ipv4/BUILD index 7fc12e229..6252614ec 100644 --- a/pkg/tcpip/network/ipv4/BUILD +++ b/pkg/tcpip/network/ipv4/BUILD @@ -29,6 +29,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv4/icmp.go b/pkg/tcpip/network/ipv4/icmp.go index 3407755ed..9b5e37fee 100644 --- a/pkg/tcpip/network/ipv4/icmp.go +++ b/pkg/tcpip/network/ipv4/icmp.go @@ -15,6 +15,7 @@ package ipv4 import ( + "errors" "fmt" "gvisor.dev/gvisor/pkg/tcpip" @@ -23,10 +24,10 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/stack" ) -// handleControl handles the case when an ICMP packet contains the headers of -// the original packet that caused the ICMP one to be sent. This information is -// used to find out which transport endpoint must be notified about the ICMP -// packet. +// handleControl handles the case when an ICMP error packet contains the headers +// of the original packet that caused the ICMP one to be sent. This information +// is used to find out which transport endpoint must be notified about the ICMP +// packet. We only expect the payload, not the enclosing ICMP packet. func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { h, ok := pkt.Data.PullUp(header.IPv4MinimumSize) if !ok { @@ -41,8 +42,8 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // // Drop packet if it doesn't have the basic IPv4 header or if the // original source address doesn't match an address we own. - src := hdr.SourceAddress() - if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, src) == 0 { + srcAddr := hdr.SourceAddress() + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, srcAddr) == 0 { return } @@ -57,11 +58,11 @@ func (e *endpoint) handleControl(typ stack.ControlType, extra uint32, pkt *stack // Skip the ip header, then deliver control message. pkt.Data.TrimFront(hlen) p := hdr.TransportProtocol() - e.dispatcher.DeliverTransportControlPacket(src, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) + e.dispatcher.DeliverTransportControlPacket(srcAddr, hdr.DestinationAddress(), ProtocolNumber, p, typ, extra, pkt) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { - stats := r.Stats() +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer) { + stats := e.protocol.stack.Stats() received := stats.ICMP.V4PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their // TransportHeader fields set. See icmp/protocol.go:protocol.Parse for a @@ -73,20 +74,65 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { } h := header.ICMPv4(v) + // Only do in-stack processing if the checksum is correct. + if header.ChecksumVV(pkt.Data, 0 /* initial */) != 0xffff { + received.Invalid.Increment() + // It's possible that a raw socket expects to receive this regardless + // of checksum errors. If it's an echo request we know it's safe because + // we are the only handler, however other types do not cope well with + // packets with checksum errors. + switch h.Type() { + case header.ICMPv4Echo: + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + } + return + } + + iph := header.IPv4(pkt.NetworkHeader().View()) + var newOptions header.IPv4Options + if len(iph) > header.IPv4MinimumSize { + // RFC 1122 section 3.2.2.6 (page 43) (and similar for other round trip + // type ICMP packets): + // If a Record Route and/or Time Stamp option is received in an + // ICMP Echo Request, this option (these options) SHOULD be + // updated to include the current host and included in the IP + // header of the Echo Reply message, without "truncation". + // Thus, the recorded route will be for the entire round trip. + // + // So we need to let the option processor know how it should handle them. + var op optionsUsage + if h.Type() == header.ICMPv4Echo { + op = &optionUsageEcho{} + } else { + op = &optionUsageReceive{} + } + aux, tmp, err := e.processIPOptions(pkt, iph.Options(), op) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() + } + return + } + newOptions = tmp + } + // TODO(b/112892170): Meaningfully handle all ICMP types. switch h.Type() { case header.ICMPv4Echo: received.Echo.Increment() - // Only send a reply if the checksum is valid. - headerChecksum := h.Checksum() - h.SetChecksum(0) - calculatedChecksum := ^header.ChecksumVV(pkt.Data, 0 /* initial */) - h.SetChecksum(headerChecksum) - if calculatedChecksum != headerChecksum { - // It's possible that a raw socket still expects to receive this. - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) - received.Invalid.Increment() + sent := stats.ICMP.V4PacketsSent + if !e.protocol.stack.AllowICMPMessage() { + sent.RateLimited.Increment() return } @@ -98,19 +144,27 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // waiting endpoints. Consider moving responsibility for doing the copy to // DeliverTransportPacket so that is is only done when needed. replyData := pkt.Data.ToOwnedView() - replyIPHdr := header.IPv4(append(buffer.View(nil), pkt.NetworkHeader().View()...)) + ipHdr := header.IPv4(pkt.NetworkHeader().View()) + localAddressBroadcast := pkt.NetworkPacketInfo.LocalAddressBroadcast + + // It's possible that a raw socket expects to receive this. + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) + pkt = nil - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + // Take the base of the incoming request IP header but replace the options. + replyHeaderLength := uint8(header.IPv4MinimumSize + len(newOptions)) + replyIPHdr := header.IPv4(append(iph[:header.IPv4MinimumSize:header.IPv4MinimumSize], newOptions...)) + replyIPHdr.SetHeaderLength(replyHeaderLength) // As per RFC 1122 section 3.2.1.3, when a host sends any datagram, the IP // source address MUST be one of its own IP addresses (but not a broadcast // or multicast address). - localAddr := r.LocalAddress - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(localAddr) { + localAddr := ipHdr.DestinationAddress() + if localAddressBroadcast || header.IsV4MulticastAddress(localAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, ipHdr.SourceAddress(), ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -139,7 +193,8 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { // The fields we need to alter. // // We need to produce the entire packet in the data segment in order to - // use WriteHeaderIncludedPacket(). + // use WriteHeaderIncludedPacket(). WriteHeaderIncludedPacket sets the + // total length and the header checksum so we don't need to set those here. replyIPHdr.SetSourceAddress(r.LocalAddress) replyIPHdr.SetDestinationAddress(r.RemoteAddress) replyIPHdr.SetTTL(r.DefaultTTL()) @@ -157,8 +212,6 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { }) replyPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber - // The checksum will be calculated so we don't need to do it here. - sent := stats.ICMP.V4PacketsSent if err := r.WriteHeaderIncludedPacket(replyPkt); err != nil { sent.Dropped.Increment() return @@ -168,7 +221,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { case header.ICMPv4EchoReply: received.EchoReply.Increment() - e.dispatcher.DeliverTransportPacket(r, header.ICMPv4ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv4ProtocolNumber, pkt) case header.ICMPv4DstUnreachable: received.DstUnreachable.Increment() @@ -182,8 +235,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer) { e.handleControl(stack.ControlPortUnreachable, 0, pkt) case header.ICMPv4FragmentationNeeded: - mtu := uint32(h.MTU()) - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(uint32(h.MTU()), header.IPv4MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) } case header.ICMPv4SrcQuench: @@ -234,12 +290,31 @@ type icmpReasonProtoUnreachable struct{} func (*icmpReasonProtoUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + +// icmpReasonParamProblem is an error to use to request a Parameter Problem +// message to be sent. +type icmpReasonParamProblem struct { + pointer byte +} + +func (*icmpReasonParamProblem) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv4 and sends it back to the remote device that sent // the problematic packet. It incorporates as much of that packet as // possible as well as any error metadata as is available. returnError // expects pkt to hold a valid IPv4 packet as per the wire format. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv4(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // We check we are responding only when we are allowed to. // See RFC 1812 section 4.3.2.7 (shown below). // @@ -263,8 +338,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // // TODO(gvisor.dev/issues/4058): Make sure we don't send ICMP errors in // response to a non-initial fragment, but it currently can not happen. - - if r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || r.RemoteAddress == header.IPv4Any { + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(origIPHdrDst) || origIPHdrSrc == header.IPv4Any { return nil } @@ -272,14 +346,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil sent := p.stack.Stats().ICMP.V4PacketsSent if !p.stack.AllowICMPMessage() { @@ -287,11 +358,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - networkHeader := pkt.NetworkHeader().View() transportHeader := pkt.TransportHeader().View() // Don't respond to icmp error packets. - if header.IPv4(networkHeader).Protocol() == uint8(header.ICMPv4ProtocolNumber) { + if origIPHdr.Protocol() == uint8(header.ICMPv4ProtocolNumber) { // TODO(gvisor.dev/issue/3810): // Unfortunately the current stack pretty much always has ICMPv4 headers // in the Data section of the packet but there is no guarantee that is the @@ -348,7 +418,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac return nil } - payloadLen := networkHeader.Size() + transportHeader.Size() + pkt.Data.Size() + payloadLen := len(origIPHdr) + transportHeader.Size() + pkt.Data.Size() if payloadLen > available { payloadLen = available } @@ -360,7 +430,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // view with the entire incoming IP packet reassembled and truncated as // required. This is now the payload of the new ICMP packet and no longer // considered a packet in its own right. - newHeader := append(buffer.View(nil), networkHeader...) + newHeader := append(buffer.View(nil), origIPHdr...) newHeader = append(newHeader, transportHeader...) payload := newHeader.ToVectorisedView() payload.AppendView(pkt.Data.ToView()) @@ -374,17 +444,29 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpPkt.TransportProtocolNumber = header.ICMPv4ProtocolNumber icmpHdr := header.ICMPv4(icmpPkt.TransportHeader().Push(header.ICMPv4MinimumSize)) - switch reason.(type) { + var counter *tcpip.StatCounter + switch reason := reason.(type) { case *icmpReasonPortUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4PortUnreachable) + counter = sent.DstUnreachable case *icmpReasonProtoUnreachable: + icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetCode(header.ICMPv4ProtoUnreachable) + counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv4TimeExceeded) + icmpHdr.SetCode(header.ICMPv4ReassemblyTimeout) + counter = sent.TimeExceeded + case *icmpReasonParamProblem: + icmpHdr.SetType(header.ICMPv4ParamProblem) + icmpHdr.SetCode(header.ICMPv4UnusedCode) + icmpHdr.SetPointer(reason.pointer) + counter = sent.ParamProblem default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } - icmpHdr.SetType(header.ICMPv4DstUnreachable) icmpHdr.SetChecksum(header.ICMPv4Checksum(icmpHdr, icmpPkt.Data)) - counter := sent.DstUnreachable if err := route.WritePacket( nil, /* gso */ diff --git a/pkg/tcpip/network/ipv4/ipv4.go b/pkg/tcpip/network/ipv4/ipv4.go index e7c58ae0a..cfd0c505a 100644 --- a/pkg/tcpip/network/ipv4/ipv4.go +++ b/pkg/tcpip/network/ipv4/ipv4.go @@ -16,7 +16,9 @@ package ipv4 import ( + "errors" "fmt" + "math" "sync/atomic" "time" @@ -31,6 +33,8 @@ import ( ) const ( + // ReassembleTimeout is the time a packet stays in the reassembly + // system before being evicted. // As per RFC 791 section 3.2: // The current recommendation for the initial timer setting is 15 seconds. // This may be changed as experience with this protocol accumulates. @@ -38,7 +42,7 @@ const ( // Considering that it is an old recommendation, we use the same reassembly // timeout that linux defines, which is 30 seconds: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ip.h#L138 - reassembleTimeout = 30 * time.Second + ReassembleTimeout = 30 * time.Second // ProtocolNumber is the ipv4 protocol number. ProtocolNumber = header.IPv4ProtocolNumber @@ -176,7 +180,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv4MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv4 headers (and @@ -211,18 +219,15 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) -} - // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) +// original packet. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + // Round the MTU down to align to 8 bytes. + fragmentPayloadSize := networkMTU &^ 7 networkHeader := header.IPv4(pkt.NetworkHeader().View()) - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, pkt.AvailableHeaderBytes()+len(networkHeader)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadSize, pkt.AvailableHeaderBytes()+len(networkHeader)) var n int for { @@ -247,8 +252,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet // iptables filtering. All packets that reach here are locally // generated. nicName := e.protocol.stack.FindNICNameFromID(e.nic.ID()) - ipt := e.protocol.stack.IPTables() - if ok := ipt.Check(stack.Output, pkt, gso, r, "", nicName); !ok { + if ok := e.protocol.stack.IPTables().Check(stack.Output, pkt, gso, r, "", nicName); !ok { // iptables is telling us to drop the packet. r.Stats().IP.IPTablesOutputDropped.Increment() return nil @@ -265,23 +269,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet netHeader := header.IPv4(pkt.NetworkHeader().View()) ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()) if err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - e.HandlePacket(&loopedR, pkt) - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -292,6 +313,7 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(remain)) return err } + if err := e.nic.WritePacket(r, gso, ProtocolNumber, pkt); err != nil { r.Stats().IP.OutgoingPacketErrors.Increment() return err @@ -311,17 +333,23 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe for pkt := pkts.Front(); pkt != nil; pkt = pkt.Next() { e.addIPHeader(r, pkt, params) - if e.packetMustBeFragmented(pkt, gso) { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pkt - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pkt, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pkt, fragPkt) pkt = fragPkt return nil }); err != nil { - panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", e.nic.MTU(), err)) + panic(fmt.Sprintf("e.handleFragments(_, _, %d, _, _) = %s", networkMTU, err)) } // Remove the packet that was just fragmented and process the rest. pkts.Remove(originalPkt) @@ -355,10 +383,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv4(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -385,6 +415,16 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu if !ok { return tcpip.ErrMalformedHeader } + + hdrLen := header.IPv4(h).HeaderLength() + if hdrLen < header.IPv4MinimumSize { + return tcpip.ErrMalformedHeader + } + + h, ok = pkt.Data.PullUp(int(hdrLen)) + if !ok { + return tcpip.ErrMalformedHeader + } ip := header.IPv4(h) // Always set the total length. @@ -429,14 +469,17 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv4 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv4(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -462,7 +505,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // is all 1 bits (-0 in 1's complement arithmetic), the check // succeeds. if h.CalculateChecksum() != 0xffff { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } @@ -470,8 +513,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // When a host sends any datagram, the IP source address MUST // be one of its own IP addresses (but not a broadcast or // multicast address). - if r.IsOutboundBroadcast() || header.IsV4MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if pkt.NetworkPacketInfo.RemoteAddressBroadcast || header.IsV4MulticastAddress(h.SourceAddress()) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -480,7 +523,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -488,8 +531,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { if pkt.Data.Size()+pkt.TransportHeader().View().Size() == 0 { // Drop the packet as it's marked as a fragment but has // no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } // The packet is a fragment, let's try to reassemble it. @@ -502,10 +545,30 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // size). Otherwise the packet would've been rejected as invalid before // reaching here. if int(start)+pkt.Data.Size() > header.IPv4MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } + + // Set up a callback in case we need to send a Time Exceeded Message, as per + // RFC 792: + // + // If a host reassembling a fragmented datagram cannot complete the + // reassembly due to missing fragments within its time limit it discards + // the datagram, and it may send a time exceeded message. + // + // If fragment zero is not available then no time exceeded need be sent at + // all. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } + } + } + var ready bool var err error proto := h.Protocol() @@ -523,29 +586,56 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { h.More(), proto, pkt.Data, + releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if !ready { return } + + // The reassembler doesn't take care of fixing up the header, so we need + // to do it here. + h.SetTotalLength(uint16(pkt.Data.Size() + len((h)))) + h.SetFlagsFragmentOffset(0, 0) } + stats.IP.PacketsDelivered.Increment() - r.Stats().IP.PacketsDelivered.Increment() p := h.TransportProtocol() if p == header.ICMPv4ProtocolNumber { // TODO(gvisor.dev/issues/3810): when we sort out ICMP and transport // headers, the setting of the transport number here should be // unnecessary and removed. pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt) + e.handleICMP(pkt) return } + if len(h.Options()) != 0 { + // TODO(gvisor.dev/issue/4586): + // When we add forwarding support we should use the verified options + // rather than just throwing them away. + aux, _, err := e.processIPOptions(pkt, h.Options(), &optionUsageReceive{}) + if err != nil { + switch { + case + errors.Is(err, header.ErrIPv4OptDuplicate), + errors.Is(err, errIPv4RecordRouteOptInvalidPointer), + errors.Is(err, errIPv4RecordRouteOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidLength), + errors.Is(err, errIPv4TimestampOptInvalidPointer), + errors.Is(err, errIPv4TimestampOptOverflow): + _ = e.protocol.returnError(&icmpReasonParamProblem{pointer: aux}, pkt) + stats.MalformedRcvdPackets.Increment() + stats.IP.MalformedPacketsReceived.Increment() + } + return + } + } - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC: 1122 Section 3.2.2.1 A host SHOULD generate Destination @@ -553,13 +643,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // 3 (Port Unreachable), when the designated transport protocol // (e.g., UDP) is unable to demultiplex the datagram but has no // protocol mechanism to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC: 1122 Section 3.2.2.1 // A host SHOULD generate Destination Unreachable messages with code: // 2 (Protocol Unreachable), when the designated transport protocol // is not supported - _ = e.protocol.returnError(r, &icmpReasonProtoUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonProtoUnreachable{}, pkt) default: panic(fmt.Sprintf("unrecognized result from DeliverTransportPacket = %d", res)) } @@ -602,7 +692,7 @@ func (e *endpoint) AcquireAssignedAddress(localAddr tcpip.Address, allowTemp boo loopback := e.nic.IsLoopback() addressEndpoint := e.mu.addressableEndpointState.ReadOnly().AddrOrMatching(localAddr, allowTemp, func(addressEndpoint stack.AddressEndpoint) bool { - subnet := addressEndpoint.AddressWithPrefix().Subnet() + subnet := addressEndpoint.Subnet() // IPv4 has a notion of a subnet broadcast address and considers the // loopback interface bound to an address's whole subnet (on linux). return subnet.IsBroadcast(localAddr) || (loopback && subnet.Contains(localAddr)) @@ -778,26 +868,32 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload mtu. +func calculateNetworkMTU(linkMTU, networkHeaderSize uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv4MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState } - return mtu - header.IPv4MinimumSize -} -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - if mtu > MaxTotalSize { - mtu = MaxTotalSize + // As per RFC 791 section 3.1, an IPv4 header cannot exceed 60 bytes in + // length: + // The maximal internet header is 60 octets, and a typical internet header + // is 20 octets, allowing a margin for headers of higher level protocols. + if networkHeaderSize > header.IPv4MaximumHeaderSize { + return 0, tcpip.ErrMalformedHeader } - mtu -= uint32(pkt.NetworkHeader().View().Size()) - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - return mtu + + networkMTU := linkMTU + if networkMTU > MaxTotalSize { + networkMTU = MaxTotalSize + } + + return networkMTU - uint32(networkHeaderSize), nil +} + +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // addressToUint32 translates an IPv4 address into its little endian uint32 @@ -836,7 +932,7 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { ids: ids, hashIV: hashIV, defaultTTL: DefaultTTL, - fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(fragmentblockSize, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), } } @@ -846,6 +942,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head originalIPHeaderLength := len(originalIPHeader) nextFragIPHeader := header.IPv4(fragPkt.NetworkHeader().Push(originalIPHeaderLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber if copied := copy(nextFragIPHeader, originalIPHeader); copied != len(originalIPHeader) { panic(fmt.Sprintf("wrong number of bytes copied into fragmentIPHeaders: got = %d, want = %d", copied, originalIPHeaderLength)) @@ -862,3 +959,324 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeader head return fragPkt, more } + +// optionAction describes possible actions that may be taken on an option +// while processing it. +type optionAction uint8 + +const ( + // optionRemove says that the option should not be in the output option set. + optionRemove optionAction = iota + + // optionProcess says that the option should be fully processed. + optionProcess + + // optionVerify says the option should be checked and passed unchanged. + optionVerify + + // optionPass says to pass the output set without checking. + optionPass +) + +// optionActions list what to do for each option in a given scenario. +type optionActions struct { + // timestamp controls what to do with a Timestamp option. + timestamp optionAction + + // recordroute controls what to do with a Record Route option. + recordRoute optionAction + + // unknown controls what to do with an unknown option. + unknown optionAction +} + +// optionsUsage specifies the ways options may be operated upon for a given +// scenario during packet processing. +type optionsUsage interface { + actions() optionActions +} + +// optionUsageReceive implements optionsUsage for received packets. +type optionUsageReceive struct{} + +// actions implements optionsUsage. +func (*optionUsageReceive) actions() optionActions { + return optionActions{ + timestamp: optionVerify, + recordRoute: optionVerify, + unknown: optionPass, + } +} + +// TODO(gvisor.dev/issue/4586): Add an entry here for forwarding when it +// is enabled (Process, Process, Pass) and for fragmenting (Process, Process, +// Pass for frag1, but Remove,Remove,Remove for all other frags). + +// optionUsageEcho implements optionsUsage for echo packet processing. +type optionUsageEcho struct{} + +// actions implements optionsUsage. +func (*optionUsageEcho) actions() optionActions { + return optionActions{ + timestamp: optionProcess, + recordRoute: optionProcess, + unknown: optionRemove, + } +} + +var ( + errIPv4TimestampOptInvalidLength = errors.New("invalid Timestamp length") + errIPv4TimestampOptInvalidPointer = errors.New("invalid Timestamp pointer") + errIPv4TimestampOptOverflow = errors.New("overflow in Timestamp") + errIPv4TimestampOptInvalidFlags = errors.New("invalid Timestamp flags") +) + +// handleTimestamp does any required processing on a Timestamp option +// in place. +func handleTimestamp(tsOpt header.IPv4OptionTimestamp, localAddress tcpip.Address, clock tcpip.Clock, usage optionsUsage) (uint8, error) { + flags := tsOpt.Flags() + var entrySize uint8 + switch flags { + case header.IPv4OptionTimestampOnlyFlag: + entrySize = header.IPv4OptionTimestampSize + case + header.IPv4OptionTimestampWithIPFlag, + header.IPv4OptionTimestampWithPredefinedIPFlag: + entrySize = header.IPv4OptionTimestampWithAddrSize + default: + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptInvalidFlags + } + + pointer := tsOpt.Pointer() + // To simplify processing below, base further work on the array of timestamps + // beyond the header, rather than on the whole option. Also to aid + // calculations set 'nextSlot' to be 0 based as in the packet it is 1 based. + nextSlot := pointer - (header.IPv4OptionTimestampHdrLength + 1) + optLen := tsOpt.Size() + dataLength := optLen - header.IPv4OptionTimestampHdrLength + + // In the section below, we verify the pointer, length and overflow counter + // fields of the option. The distinction is in which byte you return as being + // in error in the ICMP packet. Offsets 1 (length), 2 pointer) + // or 3 (overflowed counter). + // + // The following RFC sections cover this section: + // + // RFC 791 (page 22): + // If there is some room but not enough room for a full timestamp + // to be inserted, or the overflow count itself overflows, the + // original datagram is considered to be in error and is discarded. + // In either case an ICMP parameter problem message may be sent to + // the source host [3]. + // + // You can get this situation in two ways. Firstly if the data area is not + // a multiple of the entry size or secondly, if the pointer is not at a + // multiple of the entry size. The wording of the RFC suggests that + // this is not an error until you actually run out of space. + if pointer > optLen { + // RFC 791 (page 22) says we should switch to using the overflow count. + // If the timestamp data area is already full (the pointer exceeds + // the length) the datagram is forwarded without inserting the + // timestamp, but the overflow count is incremented by one. + if flags == header.IPv4OptionTimestampWithPredefinedIPFlag { + // By definition we have nothing to do. + return 0, nil + } + + if tsOpt.IncOverflow() != 0 { + return 0, nil + } + // The overflow count is also full. + return header.IPv4OptTSOFLWAndFLGOffset, errIPv4TimestampOptOverflow + } + if nextSlot+entrySize > dataLength { + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. + if false { + // We must select Pointer for Linux compatibility, even if + // only the length is bad. + // The Linux code is at (in October 2020) + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L367-L370 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + // which doesn't distinguish between which of optptr[2] or optlen + // is wrong, but just arbitrarily decides on optptr+2. + if dataLength%entrySize != 0 { + // The Data section size should be a multiple of the expected + // timestamp entry size. + return header.IPv4OptionLengthOffset, errIPv4TimestampOptInvalidLength + } + // If the size is OK, the pointer must be corrupted. + } + return header.IPv4OptTSPointerOffset, errIPv4TimestampOptInvalidPointer + } + + if usage.actions().timestamp == optionProcess { + tsOpt.UpdateTimestamp(localAddress, clock) + } + return 0, nil +} + +var ( + errIPv4RecordRouteOptInvalidLength = errors.New("invalid length in Record Route") + errIPv4RecordRouteOptInvalidPointer = errors.New("invalid pointer in Record Route") +) + +// handleRecordRoute checks and processes a Record route option. It is much +// like the timestamp type 1 option, but without timestamps. The passed in +// address is stored in the option in the correct spot if possible. +func handleRecordRoute(rrOpt header.IPv4OptionRecordRoute, localAddress tcpip.Address, usage optionsUsage) (uint8, error) { + optlen := rrOpt.Size() + + if optlen < header.IPv4AddressSize+header.IPv4OptionRecordRouteHdrLength { + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + + nextSlot := rrOpt.Pointer() - 1 // Pointer is 1 based. + + // RFC 791 page 21 says + // If the route data area is already full (the pointer exceeds the + // length) the datagram is forwarded without inserting the address + // into the recorded route. If there is some room but not enough + // room for a full address to be inserted, the original datagram is + // considered to be in error and is discarded. In either case an + // ICMP parameter problem message may be sent to the source + // host. + // The use of the words "In either case" suggests that a 'full' RR option + // could generate an ICMP at every hop after it fills up. We chose to not + // do this (as do most implementations). It is probable that the inclusion + // of these words is a copy/paste error from the timestamp option where + // there are two failure reasons given. + if nextSlot >= optlen { + return 0, nil + } + + // The data area isn't full but there isn't room for a new entry. + // Either Length or Pointer could be bad. We must select Pointer for Linux + // compatibility, even if only the length is bad. + if nextSlot+header.IPv4AddressSize > optlen { + if false { + // This is what we would do if we were not being Linux compatible. + // Check for bad pointer or length value. Must be a multiple of 4 after + // accounting for the 3 byte header and not within that header. + // RFC 791, page 20 says: + // The pointer is relative to this option, and the + // smallest legal value for the pointer is 4. + // + // A recorded route is composed of a series of internet addresses. + // Each internet address is 32 bits or 4 octets. + // Linux skips this test so we must too. See Linux code at: + // https://github.com/torvalds/linux/blob/bbf5c979011a099af5dc76498918ed7df445635b/net/ipv4/ip_options.c#L338-L341 + // if (optptr[2]+3 > optlen) { + // pp_ptr = optptr + 2; + // goto error; + // } + if (optlen-header.IPv4OptionRecordRouteHdrLength)%header.IPv4AddressSize != 0 { + // Length is bad, not on integral number of slots. + return header.IPv4OptionLengthOffset, errIPv4RecordRouteOptInvalidLength + } + // If not length, the fault must be with the pointer. + } + return header.IPv4OptRRPointerOffset, errIPv4RecordRouteOptInvalidPointer + } + if usage.actions().recordRoute == optionVerify { + return 0, nil + } + rrOpt.StoreAddress(localAddress) + return 0, nil +} + +// processIPOptions parses the IPv4 options and produces a new set of options +// suitable for use in the next step of packet processing as informed by usage. +// The original will not be touched. +// +// Returns +// - The location of an error if there was one (or 0 if no error) +// - If there is an error, information as to what it was was. +// - The replacement option set. +func (e *endpoint) processIPOptions(pkt *stack.PacketBuffer, orig header.IPv4Options, usage optionsUsage) (uint8, header.IPv4Options, error) { + stats := e.protocol.stack.Stats() + opts := header.IPv4Options(orig) + optIter := opts.MakeIterator() + + // Each option other than NOP must only appear (RFC 791 section 3.1, at the + // definition of every type). Keep track of each of the possible types in + // the 8 bit 'type' field. + var seenOptions [math.MaxUint8 + 1]bool + + // TODO(gvisor.dev/issue/4586): + // This will need tweaking when we start really forwarding packets + // as we may need to get two addresses, for rx and tx interfaces. + // We will also have to take usage into account. + prefixedAddress, err := e.protocol.stack.GetMainNICAddress(e.nic.ID(), ProtocolNumber) + localAddress := prefixedAddress.Address + if err != nil { + h := header.IPv4(pkt.NetworkHeader().View()) + dstAddr := h.DestinationAddress() + if pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(dstAddr) { + return 0 /* errCursor */, nil, header.ErrIPv4OptionAddress + } + localAddress = dstAddr + } + + for { + option, done, err := optIter.Next() + if done || err != nil { + return optIter.ErrCursor, optIter.Finalize(), err + } + optType := option.Type() + if optType == header.IPv4OptionNOPType { + optIter.PushNOPOrEnd(optType) + continue + } + if optType == header.IPv4OptionListEndType { + optIter.PushNOPOrEnd(optType) + return 0 /* errCursor */, optIter.Finalize(), nil /* err */ + } + + // check for repeating options (multiple NOPs are OK) + if seenOptions[optType] { + return optIter.ErrCursor, nil, header.ErrIPv4OptDuplicate + } + seenOptions[optType] = true + + optLen := int(option.Size()) + switch option := option.(type) { + case *header.IPv4OptionTimestamp: + stats.IP.OptionTSReceived.Increment() + if usage.actions().timestamp != optionRemove { + clock := e.protocol.stack.Clock() + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleTimestamp(header.IPv4OptionTimestamp(newBuffer), localAddress, clock, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + case *header.IPv4OptionRecordRoute: + stats.IP.OptionRRReceived.Increment() + if usage.actions().recordRoute != optionRemove { + newBuffer := optIter.RemainingBuffer()[:len(*option)] + _ = copy(newBuffer, option.Contents()) + offset, err := handleRecordRoute(header.IPv4OptionRecordRoute(newBuffer), localAddress, usage) + if err != nil { + return optIter.ErrCursor + offset, nil, err + } + optIter.ConsumeBuffer(optLen) + } + + default: + stats.IP.OptionUnknownReceived.Increment() + if usage.actions().unknown == optionPass { + newBuffer := optIter.RemainingBuffer()[:optLen] + // Arguments already heavily checked.. ignore result. + _ = copy(newBuffer, option.Contents()) + optIter.ConsumeBuffer(optLen) + } + } + } +} diff --git a/pkg/tcpip/network/ipv4/ipv4_test.go b/pkg/tcpip/network/ipv4/ipv4_test.go index fee11bb38..c7f434591 100644 --- a/pkg/tcpip/network/ipv4/ipv4_test.go +++ b/pkg/tcpip/network/ipv4/ipv4_test.go @@ -21,11 +21,13 @@ import ( "math" "net" "testing" + "time" "github.com/google/go-cmp/cmp" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/link/sniffer" @@ -39,7 +41,10 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -const extraHeaderReserve = 50 +const ( + extraHeaderReserve = 50 + defaultMTU = 65536 +) func TestExcludeBroadcast(t *testing.T) { s := stack.New(stack.Options{ @@ -47,7 +52,6 @@ func TestExcludeBroadcast(t *testing.T) { TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - const defaultMTU = 65536 ep := stack.LinkEndpoint(channel.New(256, defaultMTU, "")) if testing.Verbose() { ep = sniffer.New(ep) @@ -103,7 +107,6 @@ func TestExcludeBroadcast(t *testing.T) { // checks the response. func TestIPv4Sanity(t *testing.T) { const ( - defaultMTU = header.IPv6MinimumMTU ttl = 255 nicID = 1 randomSequence = 123 @@ -118,27 +121,29 @@ func TestIPv4Sanity(t *testing.T) { ) tests := []struct { - name string - headerLength uint8 // value of 0 means "use correct size" - badHeaderChecksum bool - maxTotalLength uint16 - transportProtocol uint8 - TTL uint8 - shouldFail bool - expectICMP bool - ICMPType header.ICMPv4Type - ICMPCode header.ICMPv4Code - options []byte + name string + headerLength uint8 // value of 0 means "use correct size" + badHeaderChecksum bool + maxTotalLength uint16 + transportProtocol uint8 + TTL uint8 + options []byte + replyOptions []byte // if succeeds, reply should look like this + shouldFail bool + expectErrorICMP bool + ICMPType header.ICMPv4Type + ICMPCode header.ICMPv4Code + paramProblemPointer uint8 }{ { - name: "valid", - maxTotalLength: defaultMTU, + name: "valid no options", + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, }, { name: "bad header checksum", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, badHeaderChecksum: true, @@ -157,47 +162,47 @@ func TestIPv4Sanity(t *testing.T) { // received with TTL less than 2. { name: "zero TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 0, - shouldFail: false, }, { name: "one TTL", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: 1, - shouldFail: false, }, { name: "End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{0, 0, 0, 0}, + replyOptions: []byte{0, 0, 0, 0}, }, { name: "NOP options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 1, 1}, + replyOptions: []byte{1, 1, 1, 1}, }, { name: "NOP and End options", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, options: []byte{1, 1, 0, 0}, + replyOptions: []byte{1, 1, 0, 0}, }, { name: "bad header length", headerLength: header.IPv4MinimumSize - 1, - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (0)", @@ -205,7 +210,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip - 1)", @@ -213,7 +217,6 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad total length (ip + icmp - 1)", @@ -221,28 +224,361 @@ func TestIPv4Sanity(t *testing.T) { transportProtocol: uint8(header.ICMPv4ProtocolNumber), TTL: ttl, shouldFail: true, - expectICMP: false, }, { name: "bad protocol", - maxTotalLength: defaultMTU, + maxTotalLength: ipv4.MaxTotalSize, transportProtocol: 99, TTL: ttl, shouldFail: true, - expectICMP: true, + expectErrorICMP: true, ICMPType: header.ICMPv4DstUnreachable, ICMPCode: header.ICMPv4ProtoUnreachable, }, + { + name: "timestamp option overflow", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + }, + { + name: "timestamp option overflow full", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0xF1, + // ^ Counter full (15/0xF) + 192, 168, 1, 12, + 1, 2, 3, 4, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 3, + replyOptions: []byte{}, + }, + { + name: "unknown option", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{10, 4, 9, 0}, + // ^^ + // The unknown option should be stripped out of the reply. + replyOptions: []byte{}, + }, + { + name: "bad option - length 0", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 0, 9, 0, + // ^ + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + name: "bad option - length big", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 9, 9, 0, + // ^ + // There are only 8 bytes allocated to options so 9 bytes of timestamp + // space is not possible. (Second byte) + 1, 2, 3, 4, + }, + shouldFail: true, + }, + { + // This tests for some linux compatible behaviour. + // The ICMP pointer returned is 22 for Linux but the + // error is actually in spot 21. + name: "bad option - length bad", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + // Timestamps are in multiples of 4 or 8 but never 7. + // The option space should be padded out. + options: []byte{ + 68, 7, 5, 0, + // ^ ^ Linux points here which is wrong. + // | Not a multiple of 4 + 1, 2, 3, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + { + name: "multiple type 0 with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 24, 21, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 24, 25, 0x00, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // The timestamp area is full so add to the overflow count. + name: "multiple type 1 timestamps", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 21, 0x11, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + // Overflow count is the top nibble of the 4th byte. + replyOptions: []byte{ + 68, 20, 21, 0x21, + // ^ + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + }, + }, + { + name: "multiple type 1 timestamps with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 28, 21, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + replyOptions: []byte{ + 68, 28, 29, 0x01, + 192, 168, 1, 12, + 1, 2, 3, 4, + 192, 168, 1, 13, + 5, 6, 7, 8, + 192, 168, 1, 58, // New IP Address. + 0x00, 0xad, 0x1c, 0x40, // time we expect from fakeclock + }, + }, + { + // Needs 8 bytes for a type 1 timestamp but there are only 4 free. + name: "bad timer element alignment", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 20, 17, 0x01, + // ^^ ^^ 20 byte area, next free spot at 17. + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, + 0, 0, 0, 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + }, + // End of option list with illegal option after it, which should be ignored. + { + name: "end of options list", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 68, 12, 13, 0x11, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 10, 3, 99, + }, + replyOptions: []byte{ + 68, 12, 13, 0x21, + 192, 168, 1, 12, + 1, 2, 3, 4, + 0, 0, 0, 0, // 3 bytes unknown option + }, // ^ End of options hides following bytes. + }, + { + // Timestamp with a size too small. + name: "timestamp truncated", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{68, 1, 0, 0}, + // ^ Smallest possible is 8. + shouldFail: true, + }, + { + name: "single record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 4, // 3 byte header + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "multiple record route with room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 20, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 0, 0, 0, 0, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 192, 168, 1, 58, // New IP Address. + 0, // padding to multiple of 4 bytes. + }, + }, + { + name: "single record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, + }, + replyOptions: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Unlike timestamp, this should just succeed. + name: "multiple record route with no room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 23, 24, // 3 byte header + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, + }, + replyOptions: []byte{ + 7, 23, 24, + 1, 2, 3, 4, + 5, 6, 7, 8, + 9, 10, 11, 12, + 13, 14, 15, 16, + 17, 18, 19, 20, + 0, // padding to multiple of 4 bytes. + }, + }, + { + // Confirm linux bug for bug compatibility. + // Linux returns slot 22 but the error is in slot 21. + name: "multiple record route with not enough room", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 8, 8, // 3 byte header + // ^ ^ Linux points here. We must too. + // | Not enough room. 1 byte free, need 4. + 1, 2, 3, 4, + 0, + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 2, + replyOptions: []byte{}, + }, + { + name: "duplicate record route", + maxTotalLength: ipv4.MaxTotalSize, + transportProtocol: uint8(header.ICMPv4ProtocolNumber), + TTL: ttl, + options: []byte{ + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 7, 7, 8, // 3 byte header + 1, 2, 3, 4, + 0, 0, // pad + }, + shouldFail: true, + expectErrorICMP: true, + ICMPType: header.ICMPv4ParamProblem, + ICMPCode: header.ICMPv4UnusedCode, + paramProblemPointer: header.IPv4MinimumSize + 7, + replyOptions: []byte{}, + }, } for _, test := range tests { t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4}, + Clock: clock, }) // We expect at most a single packet in response to our ICMP Echo Request. - e := channel.New(1, defaultMTU, "") + e := channel.New(1, ipv4.MaxTotalSize, "") if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _): %s", nicID, err) } @@ -250,6 +586,9 @@ func TestIPv4Sanity(t *testing.T) { if err := s.AddProtocolAddress(nicID, ipv4ProtoAddr); err != nil { t.Fatalf("AddProtocolAddress(%d, %#v): %s", nicID, ipv4ProtoAddr, err) } + // Advance the clock by some unimportant amount to make + // sure it's all set up. + clock.Advance(time.Millisecond * 0x10203040) // Default routes for IPv4 so ICMP can find a route to the remote // node when attempting to send the ICMP Echo Reply. @@ -312,14 +651,20 @@ func TestIPv4Sanity(t *testing.T) { reply, ok := e.Read() if !ok { if test.shouldFail { - if test.expectICMP { - t.Fatal("expected ICMP error response missing") + if test.expectErrorICMP { + t.Fatalf("ICMP error response (type %d, code %d) missing", test.ICMPType, test.ICMPCode) } return // Expected silent failure. } t.Fatal("expected ICMP echo reply missing") } + // We didn't expect a packet. Register our surprise but carry on to + // provide more information about what we got. + if test.shouldFail && !test.expectErrorICMP { + t.Error("unexpected packet response") + } + // Check the route that brought the packet to us. if reply.Route.LocalAddress != ipv4Addr.Address { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", reply.Route.LocalAddress, ipv4Addr.Address) @@ -328,57 +673,90 @@ func TestIPv4Sanity(t *testing.T) { t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", reply.Route.RemoteAddress, remoteIPv4Addr) } - // Make sure it's all in one buffer. - vv := buffer.NewVectorisedView(reply.Pkt.Size(), reply.Pkt.Views()) - replyIPHeader := header.IPv4(vv.ToView()) + // Make sure it's all in one buffer for checker. + replyIPHeader := header.IPv4(stack.PayloadSince(reply.Pkt.NetworkHeader())) - // At this stage we only know it's an IP header so verify that much. + // At this stage we only know it's probably an IP+ICMP header so verify + // that much. checker.IPv4(t, replyIPHeader, checker.SrcAddr(ipv4Addr.Address), checker.DstAddr(remoteIPv4Addr), + checker.ICMPv4( + checker.ICMPv4Checksum(), + ), ) - // All expected responses are ICMP packets. - if got, want := replyIPHeader.Protocol(), uint8(header.ICMPv4ProtocolNumber); got != want { - t.Fatalf("not ICMP response, got protocol %d, want = %d", got, want) + // Don't proceed any further if the checker found problems. + if t.Failed() { + t.FailNow() } - replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) - // Sanity check the response. + // OK it's ICMP. We can safely look at the type now. + replyICMPHeader := header.ICMPv4(replyIPHeader.Payload()) switch replyICMPHeader.Type() { - case header.ICMPv4DstUnreachable: + case header.ICMPv4ParamProblem: + if !test.shouldFail { + t.Fatalf("got Parameter Problem with pointer %d, wanted Echo Reply", replyICMPHeader.Pointer()) + } + if !test.expectErrorICMP { + t.Fatalf("got Parameter Problem with pointer %d, wanted no response", replyICMPHeader.Pointer()) + } checker.IPv4(t, replyIPHeader, checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), checker.IPv4HeaderLength(header.IPv4MinimumSize), checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), checker.ICMPv4Code(test.ICMPCode), - checker.ICMPv4Checksum(), + checker.ICMPv4Pointer(test.paramProblemPointer), checker.ICMPv4Payload([]byte(hdr.View())), ), ) - if !test.shouldFail || !test.expectICMP { - t.Fatalf("unexpected packet rejection, got ICMP error packet type %d, code %d", + return + case header.ICMPv4DstUnreachable: + if !test.shouldFail { + t.Fatalf("got ICMP error packet type %d, code %d, wanted Echo Reply", + header.ICMPv4DstUnreachable, replyICMPHeader.Code()) + } + if !test.expectErrorICMP { + t.Fatalf("got ICMP error packet type %d, code %d, wanted no response", header.ICMPv4DstUnreachable, replyICMPHeader.Code()) } + checker.IPv4(t, replyIPHeader, + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+requestPkt.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(test.ICMPType), + checker.ICMPv4Code(test.ICMPCode), + checker.ICMPv4Payload([]byte(hdr.View())), + ), + ) return case header.ICMPv4EchoReply: + if test.shouldFail { + if !test.expectErrorICMP { + t.Error("got Echo Reply packet, want no response") + } else { + t.Errorf("got Echo Reply, want ICMP error type %d, code %d", test.ICMPType, test.ICMPCode) + } + } + // If the IP options change size then the packet will change size, so + // some IP header fields will need to be adjusted for the checks. + sizeChange := len(test.replyOptions) - len(test.options) + checker.IPv4(t, replyIPHeader, - checker.IPv4HeaderLength(ipHeaderLength), - checker.IPv4Options(test.options), - checker.IPFullLength(uint16(requestPkt.Size())), + checker.IPv4HeaderLength(ipHeaderLength+sizeChange), + checker.IPv4Options(test.replyOptions), + checker.IPFullLength(uint16(requestPkt.Size()+sizeChange)), checker.ICMPv4( + checker.ICMPv4Checksum(), checker.ICMPv4Code(header.ICMPv4UnusedCode), checker.ICMPv4Seq(randomSequence), checker.ICMPv4Ident(randomIdent), - checker.ICMPv4Checksum(), ), ) - if test.shouldFail { - t.Fatalf("unexpected Echo Reply packet\n") - } default: - t.Fatalf("unexpected ICMP response, got type %d, want = %d or %d", - replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable) + t.Fatalf("unexpected ICMP response, got type %d, want = %d, %d or %d", + replyICMPHeader.Type(), header.ICMPv4EchoReply, header.ICMPv4DstUnreachable, header.ICMPv4ParamProblem) } }) } @@ -462,7 +840,7 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", + description: "No fragmentation", mtu: 1280, gso: nil, transportHeaderLength: 0, @@ -483,6 +861,30 @@ var fragmentationTests = []struct { }, }, { + description: "Fragmented with the minimum mtu", + mtu: header.IPv4MinimumMTU, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv4MinimumMTU + 1, + gso: nil, + transportHeaderLength: 0, + payloadSize: 100, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 48, more: true}, + {offset: 48, payloadSize: 48, more: true}, + {offset: 96, payloadSize: 4, more: false}, + }, + }, + { description: "No fragmentation with big header", mtu: 2000, gso: nil, @@ -647,43 +1049,50 @@ func TestFragmentationWritePackets(t *testing.T) { } } -// TestFragmentationErrors checks that errors are returned from write packet +// TestFragmentationErrors checks that errors are returned from WritePacket // correctly. func TestFragmentationErrors(t *testing.T) { const ttl = 42 - expectedError := tcpip.ErrAborted - fragTests := []struct { + tests := []struct { description string mtu uint32 transportHeaderLength int payloadSize int allowPackets int - fragmentCount int + outgoingErrors int + mockError *tcpip.Error + wantError *tcpip.Error }{ { description: "No frag", mtu: 2000, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 1, + outgoingErrors: 1, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 0, - fragmentCount: 3, + outgoingErrors: 3, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on second frag", mtu: 500, - transportHeaderLength: 0, payloadSize: 1000, + transportHeaderLength: 0, allowPackets: 1, - fragmentCount: 3, + outgoingErrors: 2, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, }, { description: "Error on first frag MTU smaller than header", @@ -691,28 +1100,40 @@ func TestFragmentationErrors(t *testing.T) { transportHeaderLength: 1000, payloadSize: 500, allowPackets: 0, - fragmentCount: 4, + outgoingErrors: 4, + mockError: tcpip.ErrAborted, + wantError: tcpip.ErrAborted, + }, + { + description: "Error when MTU is smaller than IPv4 minimum MTU", + mtu: header.IPv4MinimumMTU - 1, + transportHeaderLength: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, }, } - for _, ft := range fragTests { + for _, ft := range tests { t.Run(ft.description, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(ft.mtu, expectedError, ft.allowPackets) - r := buildRoute(t, ep) pkt := testutil.MakeRandPkt(ft.transportHeaderLength, extraHeaderReserve+header.IPv4MinimumSize, []int{ft.payloadSize}, header.IPv4ProtocolNumber) + ep := testutil.NewMockLinkEndpoint(ft.mtu, ft.mockError, ft.allowPackets) + r := buildRoute(t, ep) err := r.WritePacket(&stack.GSO{}, stack.NetworkHeaderParams{ Protocol: tcp.ProtocolNumber, TTL: ttl, TOS: stack.DefaultTOS, }, pkt) - if err != expectedError { - t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, expectedError) + if err != ft.wantError { + t.Errorf("got WritePacket(_, _, _) = %s, want = %s", err, ft.wantError) } - if got, want := len(ep.WrittenPackets), int(r.Stats().IP.PacketsSent.Value()); err != nil && got != want { - t.Errorf("got len(ep.WrittenPackets) = %d, want = %d", got, want) + if got := int(r.Stats().IP.PacketsSent.Value()); got != ft.allowPackets { + t.Errorf("got r.Stats().IP.PacketsSent.Value() = %d, want = %d", got, ft.allowPackets) } - if got, want := int(r.Stats().IP.OutgoingPacketErrors.Value()), ft.fragmentCount-ft.allowPackets; got != want { - t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, want) + if got := int(r.Stats().IP.OutgoingPacketErrors.Value()); got != ft.outgoingErrors { + t.Errorf("got r.Stats().IP.OutgoingPacketErrors.Value() = %d, want = %d", got, ft.outgoingErrors) } }) } @@ -744,7 +1165,6 @@ func TestInvalidFragments(t *testing.T) { autoChecksum bool // if true, the Checksum field will be overwritten. } - // These packets have both IHL and TotalLength set to 0. tests := []struct { name string fragments []fragmentData @@ -984,7 +1404,6 @@ func TestInvalidFragments(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{ ipv4.NewProtocol, @@ -1027,6 +1446,259 @@ func TestInvalidFragments(t *testing.T) { } } +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + nicID = 1 + linkAddr = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + addr1 = "\x0a\x00\x00\x01" + addr2 = "\x0a\x00\x00\x02" + tos = 0 + ident = 1 + ttl = 48 + protocol = 99 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv4fields header.IPv4Fields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 16, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 8, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: uint16(header.IPv4MinimumSize + len(data) - 16), + ID: ident, + Flags: 0, + FragmentOffset: 16, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[16:], + }, + { + ipv4fields: header.IPv4Fields{ + IHL: header.IPv4MinimumSize, + TOS: tos, + TotalLength: header.IPv4MinimumSize + 8, + ID: ident, + Flags: header.IPv4FlagMoreFragments, + FragmentOffset: 0, + TTL: ttl, + Protocol: protocol, + SrcAddr: addr1, + DstAddr: addr2, + }, + payload: []byte(data)[:8], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + ipv4.NewProtocol, + }, + Clock: clock, + }) + e := channel.New(1, 1500, linkAddr) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ipv4.ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv4ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv4EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + pktSize := header.IPv4MinimumSize + hdr := buffer.NewPrependable(pktSize) + + ip := header.IPv4(hdr.Prepend(pktSize)) + ip.Encode(&f.ipv4fields) + + ip.SetChecksum(0) + ip.SetChecksum(^ip.CalculateChecksum()) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && ip.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(header.IPv4ProtocolNumber, pkt) + } + + clock.Advance(ipv4.ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv4(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv4MinimumSize+header.ICMPv4MinimumSize+firstFragmentSent.Size())), + checker.IPv4HeaderLength(header.IPv4MinimumSize), + checker.ICMPv4( + checker.ICMPv4Type(header.ICMPv4TimeExceeded), + checker.ICMPv4Code(header.ICMPv4ReassemblyTimeout), + checker.ICMPv4Checksum(), + checker.ICMPv4Payload([]byte(firstFragmentSent)), + ), + ) + }) + } +} + // TestReceiveFragments feeds fragments in through the incoming packet path to // test reassembly func TestReceiveFragments(t *testing.T) { @@ -1506,13 +2178,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, @@ -1527,17 +2196,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, false /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, false /* ipv6 */) // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, false /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, false /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %s", err) } }, @@ -1577,7 +2243,7 @@ func TestWriteStats(t *testing.T) { t.Run(writer.name, func(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumSize+header.UDPMinimumSize, tcpip.ErrInvalidEndpointState, test.allowPackets) + ep := testutil.NewMockLinkEndpoint(header.IPv4MinimumMTU, tcpip.ErrInvalidEndpointState, test.allowPackets) rt := buildRoute(t, ep) var pkts stack.PacketBufferList @@ -1783,7 +2449,7 @@ func TestPacketQueing(t *testing.T) { for _, test := range tests { t.Run(test.name, func(t *testing.T) { - e := channel.New(1, header.IPv6MinimumMTU, host1NICLinkAddr) + e := channel.New(1, defaultMTU, host1NICLinkAddr) e.LinkEPCapabilities |= stack.CapabilityResolutionRequired s := stack.New(stack.Options{ NetworkProtocols: []stack.NetworkProtocolFactory{arp.NewProtocol, ipv4.NewProtocol}, diff --git a/pkg/tcpip/network/ipv6/BUILD b/pkg/tcpip/network/ipv6/BUILD index a30437f02..0ac24a6fb 100644 --- a/pkg/tcpip/network/ipv6/BUILD +++ b/pkg/tcpip/network/ipv6/BUILD @@ -36,6 +36,7 @@ go_test( "//pkg/tcpip", "//pkg/tcpip/buffer", "//pkg/tcpip/checker", + "//pkg/tcpip/faketime", "//pkg/tcpip/header", "//pkg/tcpip/link/channel", "//pkg/tcpip/link/sniffer", diff --git a/pkg/tcpip/network/ipv6/icmp.go b/pkg/tcpip/network/ipv6/icmp.go index ead6bedcb..8502b848c 100644 --- a/pkg/tcpip/network/ipv6/icmp.go +++ b/pkg/tcpip/network/ipv6/icmp.go @@ -124,8 +124,8 @@ func getTargetLinkAddr(it header.NDPOptionIterator) (tcpip.LinkAddress, bool) { }) } -func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragmentHeader bool) { - stats := r.Stats().ICMP +func (e *endpoint) handleICMP(pkt *stack.PacketBuffer, hasFragmentHeader bool) { + stats := e.protocol.stack.Stats().ICMP sent := stats.V6PacketsSent received := stats.V6PacketsReceived // TODO(gvisor.dev/issue/170): ICMP packets don't have their @@ -138,13 +138,15 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme } h := header.ICMPv6(v) iph := header.IPv6(pkt.NetworkHeader().View()) + srcAddr := iph.SourceAddress() + dstAddr := iph.DestinationAddress() // Validate ICMPv6 checksum before processing the packet. // // This copy is used as extra payload during the checksum calculation. payload := pkt.Data.Clone(nil) payload.TrimFront(len(h)) - if got, want := h.Checksum(), header.ICMPv6Checksum(h, iph.SourceAddress(), iph.DestinationAddress(), payload); got != want { + if got, want := h.Checksum(), header.ICMPv6Checksum(h, srcAddr, dstAddr, payload); got != want { received.Invalid.Increment() return } @@ -170,8 +172,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } pkt.Data.TrimFront(header.ICMPv6PacketTooBigMinimumSize) - mtu := header.ICMPv6(hdr).MTU() - e.handleControl(stack.ControlPacketTooBig, calculateMTU(mtu), pkt) + networkMTU, err := calculateNetworkMTU(header.ICMPv6(hdr).MTU(), header.IPv6MinimumSize) + if err != nil { + networkMTU = 0 + } + e.handleControl(stack.ControlPacketTooBig, networkMTU, pkt) case header.ICMPv6DstUnreachable: received.DstUnreachable.Increment() @@ -221,7 +226,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // we know we are also performing DAD on it). In this case we let the // stack know so it can handle such a scenario and do nothing further with // the NS. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { // We would get an error if the address no longer exists or the address // is no longer tentative (DAD resolved between the call to // hasTentativeAddr and this point). Both of these are valid scenarios: @@ -248,7 +253,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // section 5.4.3. // Is the NS targeting us? - if r.Stack().CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { + if e.protocol.stack.CheckLocalAddress(e.nic.ID(), ProtocolNumber, targetAddr) == 0 { return } @@ -274,9 +279,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // Otherwise, on link layers that have addresses this option MUST be // included in multicast solicitations and SHOULD be included in unicast // solicitations. - unspecifiedSource := r.RemoteAddress == header.IPv6Any + unspecifiedSource := srcAddr == header.IPv6Any if len(sourceLinkAddr) == 0 { - if header.IsV6MulticastAddress(r.LocalAddress) && !unspecifiedSource { + if header.IsV6MulticastAddress(dstAddr) && !unspecifiedSource { received.Invalid.Increment() return } @@ -284,9 +289,9 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } else if e.nud != nil { - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) } else { - e.linkAddrCache.AddLinkAddress(e.nic.ID(), r.RemoteAddress, sourceLinkAddr) + e.linkAddrCache.AddLinkAddress(e.nic.ID(), srcAddr, sourceLinkAddr) } // As per RFC 4861 section 7.1.1: @@ -295,7 +300,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // ... // - If the IP source address is the unspecified address, the IP // destination address is a solicited-node multicast address. - if unspecifiedSource && !header.IsSolicitedNodeAddr(r.LocalAddress) { + if unspecifiedSource && !header.IsSolicitedNodeAddr(dstAddr) { received.Invalid.Increment() return } @@ -305,7 +310,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the source of the solicitation is the unspecified address, the node // MUST [...] and multicast the advertisement to the all-nodes address. // - remoteAddr := r.RemoteAddress + remoteAddr := srcAddr if unspecifiedSource { remoteAddr = header.IPv6AllNodesMulticastAddress } @@ -462,12 +467,12 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4291 section 2.7, multicast addresses must not be used as // source addresses in IPv6 packets. - localAddr := r.LocalAddress - if header.IsV6MulticastAddress(r.LocalAddress) { + localAddr := dstAddr + if header.IsV6MulticastAddress(dstAddr) { localAddr = "" } - r, err := r.Stack().FindRoute(e.nic.ID(), localAddr, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + r, err := e.protocol.stack.FindRoute(e.nic.ID(), localAddr, srcAddr, ProtocolNumber, false /* multicastLoop */) if err != nil { // If we cannot find a route to the destination, silently drop the packet. return @@ -483,7 +488,11 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme copy(packet, icmpHdr) packet.SetType(header.ICMPv6EchoReply) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, pkt.Data)) - if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{Protocol: header.ICMPv6ProtocolNumber, TTL: r.DefaultTTL(), TOS: stack.DefaultTOS}, replyPkt); err != nil { + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: r.DefaultTTL(), + TOS: stack.DefaultTOS, + }, replyPkt); err != nil { sent.Dropped.Increment() return } @@ -495,7 +504,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme received.Invalid.Increment() return } - e.dispatcher.DeliverTransportPacket(r, header.ICMPv6ProtocolNumber, pkt) + e.dispatcher.DeliverTransportPacket(header.ICMPv6ProtocolNumber, pkt) case header.ICMPv6TimeExceeded: received.TimeExceeded.Increment() @@ -516,7 +525,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - stack := r.Stack() + stack := e.protocol.stack // Is the networking stack operating as a router? if !stack.Forwarding(ProtocolNumber) { @@ -547,7 +556,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // As per RFC 4861 section 4.1, the Source Link-Layer Address Option MUST // NOT be included when the source IP address is the unspecified address. // Otherwise, it SHOULD be included on link layers that have addresses. - if r.RemoteAddress == header.IPv6Any { + if srcAddr == header.IPv6Any { received.Invalid.Increment() return } @@ -555,7 +564,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme if e.nud != nil { // A RS with a specified source IP address modifies the NUD state // machine in the same way a reachability probe would. - e.nud.HandleProbe(r.RemoteAddress, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(srcAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } } @@ -572,7 +581,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme return } - routerAddr := iph.SourceAddress() + routerAddr := srcAddr // Is the IP Source Address a link-local address? if !header.IsV6LinkLocalAddress(routerAddr) { @@ -605,7 +614,7 @@ func (e *endpoint) handleICMP(r *stack.Route, pkt *stack.PacketBuffer, hasFragme // If the RA has the source link layer option, update the link address // cache with the link address for the advertised router. if len(sourceLinkAddr) != 0 && e.nud != nil { - e.nud.HandleProbe(routerAddr, r.LocalAddress, header.IPv6ProtocolNumber, sourceLinkAddr, e.protocol) + e.nud.HandleProbe(routerAddr, ProtocolNumber, sourceLinkAddr, e.protocol) } e.mu.Lock() @@ -648,52 +657,46 @@ func (*protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements stack.LinkAddressResolver. -func (*protocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP stack.LinkEndpoint) *tcpip.Error { - // TODO(b/148672031): Use stack.FindRoute instead of manually creating the - // route here. Note, we would need the nicID to do this properly so the right - // NIC (associated to linkEP) is used to send the NDP NS message. - r := stack.Route{ - LocalAddress: localAddr, - RemoteAddress: addr, - LocalLinkAddress: linkEP.LinkAddress(), - RemoteLinkAddress: remoteLinkAddr, +func (p *protocol) LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic stack.NetworkInterface) *tcpip.Error { + remoteAddr := targetAddr + if len(remoteLinkAddr) == 0 { + remoteAddr = header.SolicitedNodeAddr(targetAddr) + remoteLinkAddr = header.EthernetAddressFromMulticastIPv6Address(remoteAddr) } - // If a remote address is not already known, then send a multicast - // solicitation since multicast addresses have a static mapping to link - // addresses. - if len(r.RemoteLinkAddress) == 0 { - r.RemoteAddress = header.SolicitedNodeAddr(addr) - r.RemoteLinkAddress = header.EthernetAddressFromMulticastIPv6Address(r.RemoteAddress) + r, err := p.stack.FindRoute(nic.ID(), localAddr, remoteAddr, ProtocolNumber, false /* multicastLoop */) + if err != nil { + return err } + defer r.Release() + r.ResolveWith(remoteLinkAddr) optsSerializer := header.NDPOptionsSerializer{ - header.NDPSourceLinkLayerAddressOption(linkEP.LinkAddress()), + header.NDPSourceLinkLayerAddressOption(nic.LinkAddress()), } neighborSolicitSize := header.ICMPv6NeighborSolicitMinimumSize + optsSerializer.Length() pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ - ReserveHeaderBytes: int(linkEP.MaxHeaderLength()) + header.IPv6MinimumSize + neighborSolicitSize, + ReserveHeaderBytes: int(r.MaxHeaderLength()) + neighborSolicitSize, }) pkt.TransportProtocolNumber = header.ICMPv6ProtocolNumber packet := header.ICMPv6(pkt.TransportHeader().Push(neighborSolicitSize)) packet.SetType(header.ICMPv6NeighborSolicit) ns := header.NDPNeighborSolicit(packet.NDPPayload()) - ns.SetTargetAddress(addr) + ns.SetTargetAddress(targetAddr) ns.Options().Serialize(optsSerializer) packet.SetChecksum(header.ICMPv6Checksum(packet, r.LocalAddress, r.RemoteAddress, buffer.VectorisedView{})) - length := uint16(pkt.Size()) - ip := header.IPv6(pkt.NetworkHeader().Push(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: length, - NextHeader: uint8(header.ICMPv6ProtocolNumber), - HopLimit: header.NDPHopLimit, - SrcAddr: r.LocalAddress, - DstAddr: r.RemoteAddress, - }) + stat := p.stack.Stats().ICMP.V6PacketsSent + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: header.ICMPv6ProtocolNumber, + TTL: header.NDPHopLimit, + }, pkt); err != nil { + stat.Dropped.Increment() + return err + } - // TODO(stijlist): count this in ICMP stats. - return linkEP.WritePacket(&r, nil /* gso */, ProtocolNumber, pkt) + stat.NeighborSolicit.Increment() + return nil } // ResolveStaticAddress implements stack.LinkAddressResolver. @@ -747,9 +750,20 @@ type icmpReasonPortUnreachable struct{} func (*icmpReasonPortUnreachable) isICMPReason() {} +// icmpReasonReassemblyTimeout is an error where insufficient fragments are +// received to complete reassembly of a packet within a configured time after +// the reception of the first-arriving fragment of that packet. +type icmpReasonReassemblyTimeout struct{} + +func (*icmpReasonReassemblyTimeout) isICMPReason() {} + // returnError takes an error descriptor and generates the appropriate ICMP // error packet for IPv6 and sends it. -func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { +func (p *protocol) returnError(reason icmpReason, pkt *stack.PacketBuffer) *tcpip.Error { + origIPHdr := header.IPv6(pkt.NetworkHeader().View()) + origIPHdrSrc := origIPHdr.SourceAddress() + origIPHdrDst := origIPHdr.DestinationAddress() + // Only send ICMP error if the address is not a multicast v6 // address and the source is not the unspecified address. // @@ -776,7 +790,7 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac allowResponseToMulticast = reason.respondToMulticast } - if (!allowResponseToMulticast && header.IsV6MulticastAddress(r.LocalAddress)) || r.RemoteAddress == header.IPv6Any { + if (!allowResponseToMulticast && header.IsV6MulticastAddress(origIPHdrDst)) || origIPHdrSrc == header.IPv6Any { return nil } @@ -784,14 +798,11 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac // a route to it - the remote may be blocked via routing rules. We must always // consult our routing table and find a route to the remote before sending any // packet. - route, err := p.stack.FindRoute(r.NICID(), r.LocalAddress, r.RemoteAddress, ProtocolNumber, false /* multicastLoop */) + route, err := p.stack.FindRoute(pkt.NICID, origIPHdrDst, origIPHdrSrc, ProtocolNumber, false /* multicastLoop */) if err != nil { return err } defer route.Release() - // From this point on, the incoming route should no longer be used; route - // must be used to send the ICMP error. - r = nil stats := p.stack.Stats().ICMP sent := stats.V6PacketsSent @@ -839,7 +850,9 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac if payloadLen > available { payloadLen = available } - payload := buffer.NewVectorisedView(pkt.Size(), pkt.Views()) + payload := network.ToVectorisedView() + payload.AppendView(transport) + payload.Append(pkt.Data) payload.CapLength(payloadLen) newPkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ @@ -860,6 +873,10 @@ func (p *protocol) returnError(r *stack.Route, reason icmpReason, pkt *stack.Pac icmpHdr.SetType(header.ICMPv6DstUnreachable) icmpHdr.SetCode(header.ICMPv6PortUnreachable) counter = sent.DstUnreachable + case *icmpReasonReassemblyTimeout: + icmpHdr.SetType(header.ICMPv6TimeExceeded) + icmpHdr.SetCode(header.ICMPv6ReassemblyTimeout) + counter = sent.TimeExceeded default: panic(fmt.Sprintf("unsupported ICMP type %T", reason)) } diff --git a/pkg/tcpip/network/ipv6/icmp_test.go b/pkg/tcpip/network/ipv6/icmp_test.go index 8dc33c560..76013daa1 100644 --- a/pkg/tcpip/network/ipv6/icmp_test.go +++ b/pkg/tcpip/network/ipv6/icmp_test.go @@ -51,6 +51,7 @@ const ( var ( lladdr0 = header.LinkLocalAddr(linkAddr0) lladdr1 = header.LinkLocalAddr(linkAddr1) + lladdr2 = header.LinkLocalAddr(linkAddr2) ) type stubLinkEndpoint struct { @@ -86,7 +87,7 @@ type stubDispatcher struct { stack.TransportDispatcher } -func (*stubDispatcher) DeliverTransportPacket(*stack.Route, tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { +func (*stubDispatcher) DeliverTransportPacket(tcpip.TransportProtocolNumber, *stack.PacketBuffer) stack.TransportPacketDisposition { return stack.TransportPacketHandled } @@ -108,31 +109,27 @@ type stubNUDHandler struct { var _ stack.NUDHandler = (*stubNUDHandler)(nil) -func (s *stubNUDHandler) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes stack.LinkAddressResolver) { +func (s *stubNUDHandler) HandleProbe(tcpip.Address, tcpip.NetworkProtocolNumber, tcpip.LinkAddress, stack.LinkAddressResolver) { s.probeCount++ } -func (s *stubNUDHandler) HandleConfirmation(addr tcpip.Address, linkAddr tcpip.LinkAddress, flags stack.ReachabilityConfirmationFlags) { +func (s *stubNUDHandler) HandleConfirmation(tcpip.Address, tcpip.LinkAddress, stack.ReachabilityConfirmationFlags) { s.confirmationCount++ } -func (*stubNUDHandler) HandleUpperLevelConfirmation(addr tcpip.Address) { +func (*stubNUDHandler) HandleUpperLevelConfirmation(tcpip.Address) { } var _ stack.NetworkInterface = (*testInterface)(nil) type testInterface struct { - stack.NetworkLinkEndpoint - - linkAddr tcpip.LinkAddress -} + stack.LinkEndpoint -func (i *testInterface) LinkAddress() tcpip.LinkAddress { - return i.linkAddr + nicID tcpip.NICID } func (*testInterface) ID() tcpip.NICID { - return 0 + return nicID } func (*testInterface) IsLoopback() bool { @@ -147,6 +144,14 @@ func (*testInterface) Enabled() bool { return true } +func (t *testInterface) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *stack.GSO, protocol tcpip.NetworkProtocolNumber, pkt *stack.PacketBuffer) *tcpip.Error { + r := stack.Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return t.LinkEndpoint.WritePacket(&r, gso, protocol, pkt) +} + func TestICMPCounts(t *testing.T) { tests := []struct { name string @@ -277,7 +282,8 @@ func TestICMPCounts(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -419,7 +425,8 @@ func TestICMPCountsWithNeighborCache(t *testing.T) { SrcAddr: r.LocalAddress, DstAddr: r.RemoteAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } for _, typ := range types { @@ -1235,26 +1242,72 @@ func TestICMPChecksumValidationWithPayloadMultipleViews(t *testing.T) { } func TestLinkAddressRequest(t *testing.T) { + const nicID = 1 + snaddr := header.SolicitedNodeAddr(lladdr0) mcaddr := header.EthernetAddressFromMulticastIPv6Address(snaddr) tests := []struct { - name string - remoteLinkAddr tcpip.LinkAddress - expectedLinkAddr tcpip.LinkAddress - expectedAddr tcpip.Address + name string + nicAddr tcpip.Address + localAddr tcpip.Address + remoteLinkAddr tcpip.LinkAddress + + expectedErr *tcpip.Error + expectedRemoteAddr tcpip.Address + expectedRemoteLinkAddr tcpip.LinkAddress }{ { - name: "Unicast", - remoteLinkAddr: linkAddr1, - expectedLinkAddr: linkAddr1, - expectedAddr: lladdr0, + name: "Unicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, + }, + { + name: "Multicast", + nicAddr: lladdr1, + localAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedRemoteAddr: lladdr0, + expectedRemoteLinkAddr: linkAddr1, }, { - name: "Multicast", - remoteLinkAddr: "", - expectedLinkAddr: mcaddr, - expectedAddr: snaddr, + name: "Multicast with unspecified source", + nicAddr: lladdr1, + remoteLinkAddr: "", + expectedRemoteAddr: snaddr, + expectedRemoteLinkAddr: mcaddr, + }, + { + name: "Unicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with unassigned address", + localAddr: lladdr1, + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Unicast with no local address available", + remoteLinkAddr: linkAddr1, + expectedErr: tcpip.ErrNetworkUnreachable, + }, + { + name: "Multicast with no local address available", + remoteLinkAddr: "", + expectedErr: tcpip.ErrNetworkUnreachable, }, } @@ -1269,26 +1322,43 @@ func TestLinkAddressRequest(t *testing.T) { } linkEP := channel.New(defaultChannelSize, defaultMTU, linkAddr0) - if err := linkRes.LinkAddressRequest(lladdr0, lladdr1, test.remoteLinkAddr, linkEP); err != nil { - t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s", lladdr0, lladdr1, test.remoteLinkAddr, err) + if err := s.CreateNIC(nicID, linkEP); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if len(test.nicAddr) != 0 { + if err := s.AddAddress(nicID, ProtocolNumber, test.nicAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, ProtocolNumber, test.nicAddr, err) + } + } + + // We pass a test network interface to LinkAddressRequest with the same NIC + // ID and link endpoint used by the NIC we created earlier so that we can + // mock a link address request and observe the packets sent to the link + // endpoint even though the stack uses the real NIC. + if err := linkRes.LinkAddressRequest(lladdr0, test.localAddr, test.remoteLinkAddr, &testInterface{LinkEndpoint: linkEP, nicID: nicID}); err != test.expectedErr { + t.Errorf("got p.LinkAddressRequest(%s, %s, %s, _) = %s, want = %s", lladdr0, test.localAddr, test.remoteLinkAddr, err, test.expectedErr) + } + + if test.expectedErr != nil { + return } pkt, ok := linkEP.Read() if !ok { t.Fatal("expected to send a link address request") } - if pkt.Route.RemoteLinkAddress != test.expectedLinkAddr { - t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedLinkAddr) + if pkt.Route.RemoteLinkAddress != test.expectedRemoteLinkAddr { + t.Errorf("got pkt.Route.RemoteLinkAddress = %s, want = %s", pkt.Route.RemoteLinkAddress, test.expectedRemoteLinkAddr) } - if pkt.Route.RemoteAddress != test.expectedAddr { - t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedAddr) + if pkt.Route.RemoteAddress != test.expectedRemoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.expectedRemoteAddr) } if pkt.Route.LocalAddress != lladdr1 { t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, lladdr1) } checker.IPv6(t, stack.PayloadSince(pkt.Pkt.NetworkHeader()), checker.SrcAddr(lladdr1), - checker.DstAddr(test.expectedAddr), + checker.DstAddr(test.expectedRemoteAddr), checker.TTL(header.NDPHopLimit), checker.NDPNS( checker.NDPNSTargetAddress(lladdr0), @@ -1698,7 +1768,7 @@ func TestCallsToNeighborCache(t *testing.T) { t.Fatalf("cannot find protocol instance for network protocol %d", ProtocolNumber) } nudHandler := &stubNUDHandler{} - ep := netProto.NewEndpoint(&testInterface{linkAddr: linkAddr0}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) + ep := netProto.NewEndpoint(&testInterface{LinkEndpoint: channel.New(0, header.IPv6MinimumMTU, linkAddr0)}, &stubLinkAddressCache{}, nudHandler, &stubDispatcher{}) defer ep.Close() if err := ep.Enable(); err != nil { @@ -1728,7 +1798,8 @@ func TestCallsToNeighborCache(t *testing.T) { SrcAddr: r.RemoteAddress, DstAddr: r.LocalAddress, }) - ep.HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) // Confirm the endpoint calls the correct NUDHandler method. if nudHandler.probeCount != test.wantProbeCount { diff --git a/pkg/tcpip/network/ipv6/ipv6.go b/pkg/tcpip/network/ipv6/ipv6.go index 9670696c7..0526190cc 100644 --- a/pkg/tcpip/network/ipv6/ipv6.go +++ b/pkg/tcpip/network/ipv6/ipv6.go @@ -41,12 +41,12 @@ const ( // // Linux also uses 60 seconds for reassembly timeout: // https://github.com/torvalds/linux/blob/47ec5303d73ea344e84f46660fff693c57641386/include/net/ipv6.h#L456 - reassembleTimeout = 60 * time.Second + ReassembleTimeout = 60 * time.Second // ProtocolNumber is the ipv6 protocol number. ProtocolNumber = header.IPv6ProtocolNumber - // maxTotalSize is maximum size that can be encoded in the 16-bit + // maxPayloadSize is the maximum size that can be encoded in the 16-bit // PayloadLength field of the ipv6 header. maxPayloadSize = 0xffff @@ -166,7 +166,7 @@ func (e *endpoint) dupTentativeAddrDetected(addr tcpip.Address) *tcpip.Error { return err } - prefix := addressEndpoint.AddressWithPrefix().Subnet() + prefix := addressEndpoint.Subnet() switch t := addressEndpoint.ConfigType(); t { case stack.AddressConfigStatic: @@ -363,7 +363,11 @@ func (e *endpoint) DefaultTTL() uint8 { // MTU implements stack.NetworkEndpoint.MTU. It returns the link-layer MTU minus // the network layer max header length. func (e *endpoint) MTU() uint32 { - return calculateMTU(e.nic.MTU()) + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), header.IPv6MinimumSize) + if err != nil { + return 0 + } + return networkMTU } // MaxHeaderLength returns the maximum length needed by ipv6 headers (and @@ -386,27 +390,40 @@ func (e *endpoint) addIPHeader(r *stack.Route, pkt *stack.PacketBuffer, params s pkt.NetworkProtocolNumber = ProtocolNumber } -func (e *endpoint) packetMustBeFragmented(pkt *stack.PacketBuffer, gso *stack.GSO) bool { - return (gso == nil || gso.Type == stack.GSONone) && pkt.Size() > int(e.nic.MTU()) +func packetMustBeFragmented(pkt *stack.PacketBuffer, networkMTU uint32, gso *stack.GSO) bool { + payload := pkt.TransportHeader().View().Size() + pkt.Data.Size() + return (gso == nil || gso.Type == stack.GSONone) && uint32(payload) > networkMTU } // handleFragments fragments pkt and calls the handler function on each // fragment. It returns the number of fragments handled and the number of // fragments left to be processed. The IP header must already be present in the -// original packet. The mtu is the maximum size of the packets. The transport -// header protocol number is required to avoid parsing the IPv6 extension -// headers. -func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, mtu uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { - fragMTU := int(calculateFragmentInnerMTU(mtu, pkt)) - if fragMTU < pkt.TransportHeader().View().Size() { +// original packet. The transport header protocol number is required to avoid +// parsing the IPv6 extension headers. +func (e *endpoint) handleFragments(r *stack.Route, gso *stack.GSO, networkMTU uint32, pkt *stack.PacketBuffer, transProto tcpip.TransportProtocolNumber, handler func(*stack.PacketBuffer) *tcpip.Error) (int, int, *tcpip.Error) { + networkHeader := header.IPv6(pkt.NetworkHeader().View()) + + // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are + // supported for outbound packets, their length should not affect the fragment + // maximum payload length because they should only be transmitted once. + fragmentPayloadLen := (networkMTU - header.IPv6FragmentHeaderSize) &^ 7 + if fragmentPayloadLen < header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit { + // We need at least 8 bytes of space left for the fragmentable part because + // the fragment payload must obviously be non-zero and must be a multiple + // of 8 as per RFC 8200 section 4.5: + // Each complete fragment, except possibly the last ("rightmost") one, is + // an integer multiple of 8 octets long. + return 0, 1, tcpip.ErrMessageTooLong + } + + if fragmentPayloadLen < uint32(pkt.TransportHeader().View().Size()) { // As per RFC 8200 Section 4.5, the Transport Header is expected to be small // enough to fit in the first fragment. return 0, 1, tcpip.ErrMessageTooLong } - pf := fragmentation.MakePacketFragmenter(pkt, fragMTU, calculateFragmentReserve(pkt)) + pf := fragmentation.MakePacketFragmenter(pkt, fragmentPayloadLen, calculateFragmentReserve(pkt)) id := atomic.AddUint32(&e.protocol.ids[hashRoute(r, e.protocol.hashIV)%buckets], 1) - networkHeader := header.IPv6(pkt.NetworkHeader().View()) var n int for { @@ -448,28 +465,40 @@ func (e *endpoint) writePacket(r *stack.Route, gso *stack.GSO, pkt *stack.Packet if pkt.NatDone { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + // Since we rewrote the packet but it is being routed back to us, we can + // safely assume the checksum is valid. + pkt.RXTransportChecksumValidated = true + ep.HandlePacket(pkt) + } return nil } } if r.Loop&stack.PacketLoop != 0 { - loopedR := r.MakeLoopedRoute() - - e.HandlePacket(&loopedR, stack.NewPacketBuffer(stack.PacketBufferOptions{ - // The inbound path expects an unparsed packet. - Data: buffer.NewVectorisedView(pkt.Size(), pkt.Views()), - })) - - loopedR.Release() + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + loopedR := r.MakeLoopedRoute() + loopedR.PopulatePacketInfo(pkt) + loopedR.Release() + e.HandlePacket(pkt) + } } if r.Loop&stack.PacketOut == 0 { return nil } - if e.packetMustBeFragmented(pkt, gso) { - sent, remain, err := e.handleFragments(r, gso, e.nic.MTU(), pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + networkMTU, err := calculateNetworkMTU(e.nic.MTU(), uint32(pkt.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.Increment() + return err + } + + if packetMustBeFragmented(pkt, networkMTU, gso) { + sent, remain, err := e.handleFragments(r, gso, networkMTU, pkt, protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // TODO(gvisor.dev/issue/3884): Evaluate whether we want to send each // fragment one by one using WritePacket() (current strategy) or if we // want to create a PacketBufferList from the fragments and feed it to @@ -499,13 +528,20 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return pkts.Len(), nil } + linkMTU := e.nic.MTU() for pb := pkts.Front(); pb != nil; pb = pb.Next() { e.addIPHeader(r, pb, params) - if e.packetMustBeFragmented(pb, gso) { + + networkMTU, err := calculateNetworkMTU(linkMTU, uint32(pb.NetworkHeader().View().Size())) + if err != nil { + r.Stats().IP.OutgoingPacketErrors.IncrementBy(uint64(pkts.Len())) + return 0, err + } + if packetMustBeFragmented(pb, networkMTU, gso) { // Keep track of the packet that is about to be fragmented so it can be // removed once the fragmentation is done. originalPkt := pb - if _, _, err := e.handleFragments(r, gso, e.nic.MTU(), pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { + if _, _, err := e.handleFragments(r, gso, networkMTU, pb, params.Protocol, func(fragPkt *stack.PacketBuffer) *tcpip.Error { // Modify the packet list in place with the new fragments. pkts.InsertAfter(pb, fragPkt) pb = fragPkt @@ -546,10 +582,12 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe if _, ok := natPkts[pkt]; ok { netHeader := header.IPv6(pkt.NetworkHeader().View()) if ep, err := e.protocol.stack.FindNetworkEndpoint(ProtocolNumber, netHeader.DestinationAddress()); err == nil { - src := netHeader.SourceAddress() - dst := netHeader.DestinationAddress() - route := r.ReverseRoute(src, dst) - ep.HandlePacket(&route, pkt) + pkt := pkt.CloneToInbound() + if e.protocol.stack.ParsePacketBuffer(ProtocolNumber, pkt) == stack.ParsedOK { + route := r.ReverseRoute(netHeader.SourceAddress(), netHeader.DestinationAddress()) + route.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) + } n++ continue } @@ -569,7 +607,7 @@ func (e *endpoint) WritePackets(r *stack.Route, gso *stack.GSO, pkts stack.Packe return n + len(dropped), nil } -// WriteHeaderIncludedPacker implements stack.NetworkEndpoint. +// WriteHeaderIncludedPacket implements stack.NetworkEndpoint. func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBuffer) *tcpip.Error { // The packet already has an IP header, but there are a few required checks. h, ok := pkt.Data.PullUp(header.IPv6MinimumSize) @@ -607,22 +645,27 @@ func (e *endpoint) WriteHeaderIncludedPacket(r *stack.Route, pkt *stack.PacketBu // HandlePacket is called by the link layer when new ipv6 packets arrive for // this endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { if !e.isEnabled() { return } + pkt.NICID = e.nic.ID() + stats := e.protocol.stack.Stats() + h := header.IPv6(pkt.NetworkHeader().View()) if !h.IsValid(pkt.Data.Size() + pkt.NetworkHeader().View().Size() + pkt.TransportHeader().View().Size()) { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } + srcAddr := h.SourceAddress() + dstAddr := h.DestinationAddress() // As per RFC 4291 section 2.7: // Multicast addresses must not be used as source addresses in IPv6 // packets or appear in any Routing header. - if header.IsV6MulticastAddress(r.RemoteAddress) { - r.Stats().IP.InvalidSourceAddressesReceived.Increment() + if header.IsV6MulticastAddress(srcAddr) { + stats.IP.InvalidSourceAddressesReceived.Increment() return } @@ -641,7 +684,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { ipt := e.protocol.stack.IPTables() if ok := ipt.Check(stack.Input, pkt, nil, nil, "", ""); !ok { // iptables is telling us to drop the packet. - r.Stats().IP.IPTablesInputDropped.Increment() + stats.IP.IPTablesInputDropped.Increment() return } @@ -651,7 +694,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { previousHeaderStart := it.HeaderOffset() extHdr, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -663,7 +706,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // As per RFC 8200 section 4.1, the Hop By Hop extension header is // restricted to appear immediately after an IPv6 fixed header. if previousHeaderStart != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: previousHeaderStart, }, pkt) @@ -675,7 +718,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -689,7 +732,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -702,7 +745,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -727,7 +770,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // header, so we just make sure Segments Left is zero before processing // the next extension header. if extHdr.SegmentsLeft() != 0 { - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6ErroneousHeader, pointer: it.ParseOffset(), }, pkt) @@ -747,6 +790,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { continue } + fragmentFieldOffset := it.ParseOffset() + // Don't consume the iterator if we have the first fragment because we // will use it to validate that the first fragment holds the upper layer // header. @@ -762,8 +807,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { it, done, err := it.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } if done { @@ -790,8 +835,8 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { switch lastHdr.(type) { case header.IPv6RawPayloadHeader: default: - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } } @@ -799,30 +844,70 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { fragmentPayloadLen := rawPayload.Buf.Size() if fragmentPayloadLen == 0 { // Drop the packet as it's marked as a fragment but has no payload. - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + return + } + + // As per RFC 2460 Section 4.5: + // + // If the length of a fragment, as derived from the fragment packet's + // Payload Length field, is not a multiple of 8 octets and the M flag + // of that fragment is 1, then that fragment must be discarded and an + // ICMP Parameter Problem, Code 0, message should be sent to the source + // of the fragment, pointing to the Payload Length field of the + // fragment packet. + if extHdr.More() && fragmentPayloadLen%header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit != 0 { + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: header.IPv6PayloadLenOffset, + }, pkt) return } // The packet is a fragment, let's try to reassemble it. start := extHdr.FragmentOffset() * header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit - // Drop the fragment if the size of the reassembled payload would exceed - // the maximum payload size. + // As per RFC 2460 Section 4.5: + // + // If the length and offset of a fragment are such that the Payload + // Length of the packet reassembled from that fragment would exceed + // 65,535 octets, then that fragment must be discarded and an ICMP + // Parameter Problem, Code 0, message should be sent to the source of + // the fragment, pointing to the Fragment Offset field of the fragment + // packet. if int(start)+fragmentPayloadLen > header.IPv6MaximumPayloadSize { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() + _ = e.protocol.returnError(&icmpReasonParameterProblem{ + code: header.ICMPv6ErroneousHeader, + pointer: fragmentFieldOffset, + }, pkt) return } + // Set up a callback in case we need to send a Time Exceeded Message as + // per RFC 2460 Section 4.5. + var releaseCB func(bool) + if start == 0 { + pkt := pkt.Clone() + releaseCB = func(timedOut bool) { + if timedOut { + _ = e.protocol.returnError(&icmpReasonReassemblyTimeout{}, pkt) + } + } + } + // Note that pkt doesn't have its transport header set after reassembly, // and won't until DeliverNetworkPacket sets it. data, proto, ready, err := e.protocol.fragmentation.Process( // IPv6 ignores the Protocol field since the ID only needs to be unique // across source-destination pairs, as per RFC 8200 section 4.5. fragmentation.FragmentID{ - Source: h.SourceAddress(), - Destination: h.DestinationAddress(), + Source: srcAddr, + Destination: dstAddr, ID: extHdr.ID(), }, start, @@ -830,10 +915,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.More(), uint8(rawPayload.Identifier), rawPayload.Buf, + releaseCB, ) if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() - r.Stats().IP.MalformedFragmentsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedFragmentsReceived.Increment() return } pkt.Data = data @@ -852,7 +938,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { for { opt, done, err := optsIt.Next() if err != nil { - r.Stats().IP.MalformedPacketsReceived.Increment() + stats.IP.MalformedPacketsReceived.Increment() return } if done { @@ -866,7 +952,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { case header.IPv6OptionUnknownActionDiscard: return case header.IPv6OptionUnknownActionDiscardSendICMPNoMulticastDest: - if header.IsV6MulticastAddress(r.LocalAddress) { + if header.IsV6MulticastAddress(dstAddr) { return } fallthrough @@ -879,7 +965,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // ICMP Parameter Problem, Code 2, message to the packet's // Source Address, pointing to the unrecognized Option Type. // - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownOption, pointer: it.ParseOffset() + optsIt.OptionOffset(), respondToMulticast: true, @@ -902,13 +988,13 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { extHdr.Buf.TrimFront(pkt.TransportHeader().View().Size()) pkt.Data = extHdr.Buf - r.Stats().IP.PacketsDelivered.Increment() + stats.IP.PacketsDelivered.Increment() if p := tcpip.TransportProtocolNumber(extHdr.Identifier); p == header.ICMPv6ProtocolNumber { pkt.TransportProtocolNumber = p - e.handleICMP(r, pkt, hasFragmentHeader) + e.handleICMP(pkt, hasFragmentHeader) } else { - r.Stats().IP.PacketsDelivered.Increment() - switch res := e.dispatcher.DeliverTransportPacket(r, p, pkt); res { + stats.IP.PacketsDelivered.Increment() + switch res := e.dispatcher.DeliverTransportPacket(p, pkt); res { case stack.TransportPacketHandled: case stack.TransportPacketDestinationPortUnreachable: // As per RFC 4443 section 3.1: @@ -916,7 +1002,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // message with Code 4 in response to a packet for which the // transport protocol (e.g., UDP) has no listener, if that transport // protocol has no alternative means to inform the sender. - _ = e.protocol.returnError(r, &icmpReasonPortUnreachable{}, pkt) + _ = e.protocol.returnError(&icmpReasonPortUnreachable{}, pkt) case stack.TransportPacketProtocolUnreachable: // As per RFC 8200 section 4. (page 7): // Extension headers are numbered from IANA IP Protocol Numbers @@ -937,7 +1023,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { // // Which when taken together indicate that an unknown protocol should // be treated as an unrecognized next header value. - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) @@ -947,11 +1033,11 @@ func (e *endpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { } default: - _ = e.protocol.returnError(r, &icmpReasonParameterProblem{ + _ = e.protocol.returnError(&icmpReasonParameterProblem{ code: header.ICMPv6UnknownHeader, pointer: it.ParseOffset(), }, pkt) - r.Stats().UnknownProtocolRcvdPackets.Increment() + stats.UnknownProtocolRcvdPackets.Increment() return } } @@ -1427,14 +1513,31 @@ func (p *protocol) SetForwarding(v bool) { } } -// calculateMTU calculates the network-layer payload MTU based on the link-layer -// payload mtu. -func calculateMTU(mtu uint32) uint32 { - mtu -= header.IPv6MinimumSize - if mtu <= maxPayloadSize { - return mtu +// calculateNetworkMTU calculates the network-layer payload MTU based on the +// link-layer payload MTU and the length of every IPv6 header. +// Note that this is different than the Payload Length field of the IPv6 header, +// which includes the length of the extension headers. +func calculateNetworkMTU(linkMTU, networkHeadersLen uint32) (uint32, *tcpip.Error) { + if linkMTU < header.IPv6MinimumMTU { + return 0, tcpip.ErrInvalidEndpointState + } + + // As per RFC 7112 section 5, we should discard packets if their IPv6 header + // is bigger than 1280 bytes (ie, the minimum link MTU) since we do not + // support PMTU discovery: + // Hosts that do not discover the Path MTU MUST limit the IPv6 Header Chain + // length to 1280 bytes. Limiting the IPv6 Header Chain length to 1280 + // bytes ensures that the header chain length does not exceed the IPv6 + // minimum MTU. + if networkHeadersLen > header.IPv6MinimumMTU { + return 0, tcpip.ErrMalformedHeader + } + + networkMTU := linkMTU - uint32(networkHeadersLen) + if networkMTU > maxPayloadSize { + networkMTU = maxPayloadSize } - return maxPayloadSize + return networkMTU, nil } // Options holds options to configure a new protocol. @@ -1488,7 +1591,7 @@ func NewProtocolWithOptions(opts Options) stack.NetworkProtocolFactory { return func(s *stack.Stack) stack.NetworkProtocol { p := &protocol{ stack: s, - fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, reassembleTimeout, s.Clock()), + fragmentation: fragmentation.NewFragmentation(header.IPv6FragmentExtHdrFragmentOffsetBytesPerUnit, fragmentation.HighFragThreshold, fragmentation.LowFragThreshold, ReassembleTimeout, s.Clock()), ids: ids, hashIV: hashIV, @@ -1509,23 +1612,6 @@ func NewProtocol(s *stack.Stack) stack.NetworkProtocol { return NewProtocolWithOptions(Options{})(s) } -// calculateFragmentInnerMTU calculates the maximum number of bytes of -// fragmentable data a fragment can have, based on the link layer mtu and pkt's -// network header size. -func calculateFragmentInnerMTU(mtu uint32, pkt *stack.PacketBuffer) uint32 { - // TODO(gvisor.dev/issue/3912): Once the Authentication or ESP Headers are - // supported for outbound packets, their length should not affect the fragment - // MTU because they should only be transmitted once. - mtu -= uint32(pkt.NetworkHeader().View().Size()) - mtu -= header.IPv6FragmentHeaderSize - // Round the MTU down to align to 8 bytes. - mtu &^= 7 - if mtu <= maxPayloadSize { - return mtu - } - return maxPayloadSize -} - func calculateFragmentReserve(pkt *stack.PacketBuffer) int { return pkt.AvailableHeaderBytes() + pkt.NetworkHeader().View().Size() + header.IPv6FragmentHeaderSize } @@ -1560,6 +1646,7 @@ func buildNextFragment(pf *fragmentation.PacketFragmenter, originalIPHeaders hea originalIPHeadersLength := len(originalIPHeaders) fragmentIPHeadersLength := originalIPHeadersLength + header.IPv6FragmentHeaderSize fragmentIPHeaders := header.IPv6(fragPkt.NetworkHeader().Push(fragmentIPHeadersLength)) + fragPkt.NetworkProtocolNumber = ProtocolNumber // Copy the IPv6 header and any extension headers already populated. if copied := copy(fragmentIPHeaders, originalIPHeaders); copied != originalIPHeadersLength { diff --git a/pkg/tcpip/network/ipv6/ipv6_test.go b/pkg/tcpip/network/ipv6/ipv6_test.go index 297868f24..1bfcdde25 100644 --- a/pkg/tcpip/network/ipv6/ipv6_test.go +++ b/pkg/tcpip/network/ipv6/ipv6_test.go @@ -24,6 +24,7 @@ import ( "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/checker" + "gvisor.dev/gvisor/pkg/tcpip/faketime" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/link/channel" "gvisor.dev/gvisor/pkg/tcpip/network/testutil" @@ -238,7 +239,7 @@ func TestReceiveOnAllNodesMulticastAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(10, 1280, linkAddr1) + e := channel.New(10, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(1, e); err != nil { t.Fatalf("CreateNIC(_) = %s", err) } @@ -271,7 +272,7 @@ func TestReceiveOnSolicitedNodeAddr(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{test.protocolFactory}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -825,7 +826,7 @@ func TestReceiveIPv6ExtHdrs(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(1, 1280, linkAddr1) + e := channel.New(1, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1844,7 +1845,7 @@ func TestReceiveIPv6Fragments(t *testing.T) { NetworkProtocols: []stack.NetworkProtocolFactory{NewProtocol}, TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, }) - e := channel.New(0, 1280, linkAddr1) + e := channel.New(0, header.IPv6MinimumMTU, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } @@ -1912,16 +1913,19 @@ func TestReceiveIPv6Fragments(t *testing.T) { func TestInvalidIPv6Fragments(t *testing.T) { const ( - nicID = 1 - fragmentExtHdrLen = 8 + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_INVALID_IPV6_FRAGMENTS" ) - payloadGen := func(payloadLen int) []byte { - payload := make([]byte, payloadLen) - for i := 0; i < len(payload); i++ { - payload[i] = 0x30 - } - return payload + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte } tests := []struct { @@ -1929,31 +1933,64 @@ func TestInvalidIPv6Fragments(t *testing.T) { fragments []fragmentData wantMalformedIPPackets uint64 wantMalformedFragments uint64 + expectICMP bool + expectICMPType header.ICMPv6Type + expectICMPCode header.ICMPv6Code + expectICMPTypeSpecific uint32 }{ { + name: "fragment size is not a multiple of 8 and the M flag is true", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 9, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0 >> 3, + M: true, + Identification: ident, + }, + payload: []byte(data)[:9], + }, + }, + wantMalformedIPPackets: 1, + wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6PayloadLenOffset, + }, + { name: "fragments reassembled into a payload exceeding the max IPv6 payload size", fragments: []fragmentData{ { - srcAddr: addr1, - dstAddr: addr2, - nextHdr: fragmentExtHdrID, - data: buffer.NewVectorisedView( - fragmentExtHdrLen+(header.IPv6MaximumPayloadSize+1)-16, - []buffer.View{ - // Fragment extension header. - // Fragment offset = 8190, More = false, ID = 1 - buffer.View([]byte{uint8(header.UDPProtocolNumber), 0, - ((header.IPv6MaximumPayloadSize + 1) - 16) >> 8, - ((header.IPv6MaximumPayloadSize + 1) - 16) & math.MaxUint8, - 0, 0, 0, 1}), - // Payload length = 16 - payloadGen(16), - }, - ), + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: ((header.IPv6MaximumPayloadSize + 1) - 16) >> 3, + M: false, + Identification: ident, + }, + payload: []byte(data)[:16], }, }, wantMalformedIPPackets: 1, wantMalformedFragments: 1, + expectICMP: true, + expectICMPType: header.ICMPv6ParamProblem, + expectICMPCode: header.ICMPv6ErroneousHeader, + expectICMPTypeSpecific: header.IPv6MinimumSize + 2, /* offset for 'Fragment Offset' in the fragment header */ }, } @@ -1964,33 +2001,40 @@ func TestInvalidIPv6Fragments(t *testing.T) { NewProtocol, }, }) - e := channel.New(0, 1500, linkAddr1) + e := channel.New(1, 1500, linkAddr1) if err := s.CreateNIC(nicID, e); err != nil { t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) } if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, addr2, err) } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + var expectICMPPayload buffer.View for _, f := range test.fragments { - hdr := buffer.NewPrependable(header.IPv6MinimumSize) + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) - // Serialize IPv6 fixed header. - ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize)) - ip.Encode(&header.IPv6Fields{ - PayloadLength: uint16(f.data.Size()), - NextHeader: f.nextHdr, - HopLimit: 255, - SrcAddr: f.srcAddr, - DstAddr: f.dstAddr, - }) + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) vv := hdr.View().ToVectorisedView() - vv.Append(f.data) + vv.AppendView(f.payload) - e.InjectInbound(ProtocolNumber, stack.NewPacketBuffer(stack.PacketBufferOptions{ + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ Data: vv, - })) + }) + + if test.expectICMP { + expectICMPPayload = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) } if got, want := s.Stats().IP.MalformedPacketsReceived.Value(), test.wantMalformedIPPackets; got != want { @@ -1999,6 +2043,287 @@ func TestInvalidIPv6Fragments(t *testing.T) { if got, want := s.Stats().IP.MalformedFragmentsReceived.Value(), test.wantMalformedFragments; got != want { t.Errorf("got Stats.IP.MalformedFragmentsReceived = %d, want = %d", got, want) } + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+expectICMPPayload.Size())), + checker.ICMPv6( + checker.ICMPv6Type(test.expectICMPType), + checker.ICMPv6Code(test.expectICMPCode), + checker.ICMPv6TypeSpecific(test.expectICMPTypeSpecific), + checker.ICMPv6Payload([]byte(expectICMPPayload)), + ), + ) + }) + } +} + +func TestFragmentReassemblyTimeout(t *testing.T) { + const ( + addr1 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01" + addr2 = "\x0a\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x02" + linkAddr1 = tcpip.LinkAddress("\x0a\x0b\x0c\x0d\x0e\x0e") + nicID = 1 + hoplimit = 255 + ident = 1 + data = "TEST_FRAGMENT_REASSEMBLY_TIMEOUT" + ) + + type fragmentData struct { + ipv6Fields header.IPv6Fields + ipv6FragmentFields header.IPv6FragmentFields + payload []byte + } + + tests := []struct { + name string + fragments []fragmentData + expectICMP bool + }{ + { + name: "first fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "two first fragments", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + { + name: "second fragment only", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: false, + }, + { + name: "two fragments with a gap", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + }, + expectICMP: true, + }, + { + name: "two fragments with a gap in reverse order", + fragments: []fragmentData{ + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: uint16(header.IPv6FragmentHeaderSize + len(data) - 16), + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 8, + M: false, + Identification: ident, + }, + payload: []byte(data)[16:], + }, + { + ipv6Fields: header.IPv6Fields{ + PayloadLength: header.IPv6FragmentHeaderSize + 16, + NextHeader: header.IPv6FragmentHeader, + HopLimit: hoplimit, + SrcAddr: addr1, + DstAddr: addr2, + }, + ipv6FragmentFields: header.IPv6FragmentFields{ + NextHeader: uint8(header.UDPProtocolNumber), + FragmentOffset: 0, + M: true, + Identification: ident, + }, + payload: []byte(data)[:16], + }, + }, + expectICMP: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + clock := faketime.NewManualClock() + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ + NewProtocol, + }, + Clock: clock, + }) + + e := channel.New(1, 1500, linkAddr1) + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("CreateNIC(%d, _) = %s", nicID, err) + } + if err := s.AddAddress(nicID, ProtocolNumber, addr2); err != nil { + t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, header.IPv6ProtocolNumber, addr2, err) + } + s.SetRouteTable([]tcpip.Route{{ + Destination: header.IPv6EmptySubnet, + NIC: nicID, + }}) + + var firstFragmentSent buffer.View + for _, f := range test.fragments { + hdr := buffer.NewPrependable(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize) + + ip := header.IPv6(hdr.Prepend(header.IPv6MinimumSize + header.IPv6FragmentHeaderSize)) + ip.Encode(&f.ipv6Fields) + + fragHDR := header.IPv6Fragment(hdr.View()[header.IPv6MinimumSize:]) + fragHDR.Encode(&f.ipv6FragmentFields) + + vv := hdr.View().ToVectorisedView() + vv.AppendView(f.payload) + + pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ + Data: vv, + }) + + if firstFragmentSent == nil && fragHDR.FragmentOffset() == 0 { + firstFragmentSent = stack.PayloadSince(pkt.NetworkHeader()) + } + + e.InjectInbound(ProtocolNumber, pkt) + } + + clock.Advance(ReassembleTimeout) + + reply, ok := e.Read() + if !test.expectICMP { + if ok { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + return + } + if !ok { + t.Fatal("expected ICMP error message missing") + } + if firstFragmentSent == nil { + t.Fatalf("unexpected ICMP error message received: %#v", reply) + } + + checker.IPv6(t, stack.PayloadSince(reply.Pkt.NetworkHeader()), + checker.SrcAddr(addr2), + checker.DstAddr(addr1), + checker.IPFullLength(uint16(header.IPv6MinimumSize+header.ICMPv6MinimumSize+firstFragmentSent.Size())), + checker.ICMPv6( + checker.ICMPv6Type(header.ICMPv6TimeExceeded), + checker.ICMPv6Code(header.ICMPv6ReassemblyTimeout), + checker.ICMPv6Payload([]byte(firstFragmentSent)), + ), + ) }) } } @@ -2035,13 +2360,10 @@ func TestWriteStats(t *testing.T) { // Install Output DROP rule. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, @@ -2056,17 +2378,14 @@ func TestWriteStats(t *testing.T) { // of the 3 packets. t.Helper() ipt := stk.IPTables() - filter, ok := ipt.GetTable(stack.FilterTable, true /* ipv6 */) - if !ok { - t.Fatalf("failed to find filter table") - } + filter := ipt.GetTable(stack.FilterID, true /* ipv6 */) // We'll match and DROP the last packet. ruleIdx := filter.BuiltinChains[stack.Output] filter.Rules[ruleIdx].Target = &stack.DropTarget{} filter.Rules[ruleIdx].Matchers = []stack.Matcher{&limitedMatcher{nPackets - 1}} // Make sure the next rule is ACCEPT. filter.Rules[ruleIdx+1].Target = &stack.AcceptTarget{} - if err := ipt.ReplaceTable(stack.FilterTable, filter, true /* ipv6 */); err != nil { + if err := ipt.ReplaceTable(stack.FilterID, filter, true /* ipv6 */); err != nil { t.Fatalf("failed to replace table: %v", err) } }, @@ -2230,8 +2549,8 @@ var fragmentationTests = []struct { wantFragments []fragmentInfo }{ { - description: "No Fragmentation", - mtu: 1280, + description: "No fragmentation", + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 0, payloadSize: 1000, @@ -2241,7 +2560,18 @@ var fragmentationTests = []struct { }, { description: "Fragmented", - mtu: 1280, + mtu: header.IPv6MinimumMTU, + gso: nil, + transHdrLen: 0, + payloadSize: 2000, + wantFragments: []fragmentInfo{ + {offset: 0, payloadSize: 1240, more: true}, + {offset: 154, payloadSize: 776, more: false}, + }, + }, + { + description: "Fragmented with mtu not a multiple of 8", + mtu: header.IPv6MinimumMTU + 1, gso: nil, transHdrLen: 0, payloadSize: 2000, @@ -2262,7 +2592,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with gso none", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: &stack.GSO{Type: stack.GSONone}, transHdrLen: 0, payloadSize: 1400, @@ -2273,7 +2603,7 @@ var fragmentationTests = []struct { }, { description: "Fragmented with big header", - mtu: 1280, + mtu: header.IPv6MinimumMTU, gso: nil, transHdrLen: 100, payloadSize: 1200, @@ -2448,8 +2778,8 @@ func TestFragmentationErrors(t *testing.T) { wantError: tcpip.ErrAborted, }, { - description: "Error on packet with MTU smaller than transport header", - mtu: 1280, + description: "Error when MTU is smaller than transport header", + mtu: header.IPv6MinimumMTU, transHdrLen: 1500, payloadSize: 500, allowPackets: 0, @@ -2457,6 +2787,16 @@ func TestFragmentationErrors(t *testing.T) { mockError: nil, wantError: tcpip.ErrMessageTooLong, }, + { + description: "Error when MTU is smaller than IPv6 minimum MTU", + mtu: header.IPv6MinimumMTU - 1, + transHdrLen: 0, + payloadSize: 500, + allowPackets: 0, + outgoingErrors: 1, + mockError: nil, + wantError: tcpip.ErrInvalidEndpointState, + }, } for _, ft := range tests { diff --git a/pkg/tcpip/network/ipv6/ndp_test.go b/pkg/tcpip/network/ipv6/ndp_test.go index ac20f217e..981d1371a 100644 --- a/pkg/tcpip/network/ipv6/ndp_test.go +++ b/pkg/tcpip/network/ipv6/ndp_test.go @@ -341,7 +341,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi if diff := cmp.Diff(existing, n); diff != "" { t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) } neighborByAddr[n.Addr] = n } @@ -368,7 +368,7 @@ func TestNeighorSolicitationWithSourceLinkLayerOptionUsingNeighborCache(t *testi } if ok { - t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) } } }) @@ -573,6 +573,13 @@ func TestNeighorSolicitationResponse(t *testing.T) { t.Fatalf("AddAddress(%d, %d, %s) = %s", nicID, ProtocolNumber, nicAddr, err) } + s.SetRouteTable([]tcpip.Route{ + tcpip.Route{ + Destination: header.IPv6EmptySubnet, + NIC: 1, + }, + }) + ndpNSSize := header.ICMPv6NeighborSolicitMinimumSize + test.nsOpts.Length() hdr := buffer.NewPrependable(header.IPv6MinimumSize + ndpNSSize) pkt := header.ICMPv6(hdr.Prepend(ndpNSSize)) @@ -913,13 +920,13 @@ func TestNeighorAdvertisementWithTargetLinkLayerOptionUsingNeighborCache(t *test if diff := cmp.Diff(existing, n); diff != "" { t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry (-existing +got):\n%s", nicID, diff) } - t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %s", nicID, existing) + t.Fatalf("s.Neighbors(%d) returned unexpected duplicate neighbor entry: %#v", nicID, existing) } neighborByAddr[n.Addr] = n } if neigh, ok := neighborByAddr[lladdr1]; ok { - t.Fatalf("unexpectedly got neighbor entry: %s", neigh) + t.Fatalf("unexpectedly got neighbor entry: %#v", neigh) } if test.isValid { @@ -993,7 +1000,8 @@ func TestNDPValidation(t *testing.T) { if n := copy(ip[header.IPv6MinimumSize:], extensions); n != len(extensions) { t.Fatalf("expected to write %d bytes of extensions, but wrote %d", len(extensions), n) } - ep.HandlePacket(r, pkt) + r.PopulatePacketInfo(pkt) + ep.HandlePacket(pkt) } var tllData [header.NDPLinkLayerAddressSize]byte diff --git a/pkg/tcpip/stack/addressable_endpoint_state.go b/pkg/tcpip/stack/addressable_endpoint_state.go index 4d3acab96..9478f3fb7 100644 --- a/pkg/tcpip/stack/addressable_endpoint_state.go +++ b/pkg/tcpip/stack/addressable_endpoint_state.go @@ -272,6 +272,9 @@ func (a *AddressableEndpointState) addAndAcquireAddressLocked(addr tcpip.Address addrState = &addressState{ addressableEndpointState: a, addr: addr, + // Cache the subnet in addrState to avoid calls to addr.Subnet() as that + // results in allocations on every call. + subnet: addr.Subnet(), } a.mu.endpoints[addr.Address] = addrState addrState.mu.Lock() @@ -361,6 +364,8 @@ func (a *AddressableEndpointState) RemovePermanentEndpoint(ep AddressEndpoint) * return tcpip.ErrInvalidEndpointState } + a.mu.Lock() + defer a.mu.Unlock() return a.removePermanentEndpointLocked(addrState) } @@ -664,7 +669,7 @@ var _ AddressEndpoint = (*addressState)(nil) type addressState struct { addressableEndpointState *AddressableEndpointState addr tcpip.AddressWithPrefix - + subnet tcpip.Subnet // Lock ordering (from outer to inner lock ordering): // // AddressableEndpointState.mu @@ -684,6 +689,11 @@ func (a *addressState) AddressWithPrefix() tcpip.AddressWithPrefix { return a.addr } +// Subnet implements AddressEndpoint. +func (a *addressState) Subnet() tcpip.Subnet { + return a.subnet +} + // GetKind implements AddressEndpoint. func (a *addressState) GetKind() AddressKind { a.mu.RLock() diff --git a/pkg/tcpip/stack/conntrack.go b/pkg/tcpip/stack/conntrack.go index 0cd1da11f..9a17efcba 100644 --- a/pkg/tcpip/stack/conntrack.go +++ b/pkg/tcpip/stack/conntrack.go @@ -269,7 +269,7 @@ func (ct *ConnTrack) connForTID(tid tupleID) (*conn, direction) { return nil, dirOriginal } -func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *RedirectTarget) *conn { +func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, port uint16, address tcpip.Address) *conn { tid, err := packetToTupleID(pkt) if err != nil { return nil @@ -282,8 +282,8 @@ func (ct *ConnTrack) insertRedirectConn(pkt *PacketBuffer, hook Hook, rt *Redire // rule. This tuple will be used to manipulate the packet in // handlePacket. replyTID := tid.reply() - replyTID.srcAddr = rt.Addr - replyTID.srcPort = rt.Port + replyTID.srcAddr = address + replyTID.srcPort = port var manip manipType switch hook { case Prerouting: @@ -401,12 +401,12 @@ func handlePacketOutput(pkt *PacketBuffer, conn *conn, gso *GSO, r *Route, dir d // Calculate the TCP checksum and set it. tcpHeader.SetChecksum(0) - length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(header.TCPProtocolNumber, length) + length := uint16(len(tcpHeader) + pkt.Data.Size()) + xsum := header.PseudoHeaderChecksum(header.TCPProtocolNumber, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) if gso != nil && gso.NeedsCsum { tcpHeader.SetChecksum(xsum) - } else if r.Capabilities()&CapabilityTXChecksumOffload == 0 { - xsum = header.ChecksumVVWithOffset(pkt.Data, xsum, int(tcpHeader.DataOffset()), pkt.Data.Size()) + } else if r.RequiresTXTransportChecksum() { + xsum = header.ChecksumVV(pkt.Data, xsum) tcpHeader.SetChecksum(^tcpHeader.CalculateChecksum(xsum)) } diff --git a/pkg/tcpip/stack/forwarding_test.go b/pkg/tcpip/stack/forwarding_test.go index cf042309e..7a501acdc 100644 --- a/pkg/tcpip/stack/forwarding_test.go +++ b/pkg/tcpip/stack/forwarding_test.go @@ -73,9 +73,9 @@ func (*fwdTestNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fwdTestNetworkEndpoint) HandlePacket(r *Route, pkt *PacketBuffer) { +func (f *fwdTestNetworkEndpoint) HandlePacket(pkt *PacketBuffer) { // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fwdTestNetworkEndpoint) MaxHeaderLength() uint16 { @@ -178,7 +178,7 @@ func (*fwdTestNetworkProtocol) Close() {} func (*fwdTestNetworkProtocol) Wait() {} -func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (f *fwdTestNetworkProtocol) LinkAddressRequest(addr, _ tcpip.Address, remoteLinkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { if f.onLinkAddressResolved != nil { time.AfterFunc(f.addrResolveDelay, func() { f.onLinkAddressResolved(f.addrCache, f.neigh, addr, remoteLinkAddr) diff --git a/pkg/tcpip/stack/iptables.go b/pkg/tcpip/stack/iptables.go index 8d6d9a7f1..2d8c883cd 100644 --- a/pkg/tcpip/stack/iptables.go +++ b/pkg/tcpip/stack/iptables.go @@ -22,30 +22,17 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) -// tableID is an index into IPTables.tables. -type tableID int +// TableID identifies a specific table. +type TableID int +// Each value identifies a specific table. const ( - natID tableID = iota - mangleID - filterID - numTables + NATID TableID = iota + MangleID + FilterID + NumTables ) -// Table names. -const ( - NATTable = "nat" - MangleTable = "mangle" - FilterTable = "filter" -) - -// nameToID is immutable. -var nameToID = map[string]tableID{ - NATTable: natID, - MangleTable: mangleID, - FilterTable: filterID, -} - // HookUnset indicates that there is no hook set for an entrypoint or // underflow. const HookUnset = -1 @@ -57,8 +44,8 @@ const reaperDelay = 5 * time.Second // all packets. func DefaultTables() *IPTables { return &IPTables{ - v4Tables: [numTables]Table{ - natID: Table{ + v4Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -81,7 +68,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -99,7 +86,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv4ProtocolNumber}}, @@ -122,8 +109,8 @@ func DefaultTables() *IPTables { }, }, }, - v6Tables: [numTables]Table{ - natID: Table{ + v6Tables: [NumTables]Table{ + NATID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -146,7 +133,7 @@ func DefaultTables() *IPTables { Postrouting: 3, }, }, - mangleID: Table{ + MangleID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -164,7 +151,7 @@ func DefaultTables() *IPTables { Postrouting: HookUnset, }, }, - filterID: Table{ + FilterID: Table{ Rules: []Rule{ Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, Rule{Target: &AcceptTarget{NetworkProtocol: header.IPv6ProtocolNumber}}, @@ -187,10 +174,10 @@ func DefaultTables() *IPTables { }, }, }, - priorities: [NumHooks][]tableID{ - Prerouting: []tableID{mangleID, natID}, - Input: []tableID{natID, filterID}, - Output: []tableID{mangleID, natID, filterID}, + priorities: [NumHooks][]TableID{ + Prerouting: []TableID{MangleID, NATID}, + Input: []TableID{NATID, FilterID}, + Output: []TableID{MangleID, NATID, FilterID}, }, connections: ConnTrack{ seed: generateRandUint32(), @@ -229,26 +216,20 @@ func EmptyNATTable() Table { } } -// GetTable returns a table by name. -func (it *IPTables) GetTable(name string, ipv6 bool) (Table, bool) { - id, ok := nameToID[name] - if !ok { - return Table{}, false - } +// GetTable returns a table with the given id and IP version. It panics when an +// invalid id is provided. +func (it *IPTables) GetTable(id TableID, ipv6 bool) Table { it.mu.RLock() defer it.mu.RUnlock() if ipv6 { - return it.v6Tables[id], true + return it.v6Tables[id] } - return it.v4Tables[id], true + return it.v4Tables[id] } -// ReplaceTable replaces or inserts table by name. -func (it *IPTables) ReplaceTable(name string, table Table, ipv6 bool) *tcpip.Error { - id, ok := nameToID[name] - if !ok { - return tcpip.ErrInvalidOptionValue - } +// ReplaceTable replaces or inserts table by name. It panics when an invalid id +// is provided. +func (it *IPTables) ReplaceTable(id TableID, table Table, ipv6 bool) *tcpip.Error { it.mu.Lock() defer it.mu.Unlock() // If iptables is being enabled, initialize the conntrack table and @@ -311,7 +292,7 @@ func (it *IPTables) Check(hook Hook, pkt *PacketBuffer, gso *GSO, r *Route, prer for _, tableID := range priorities { // If handlePacket already NATed the packet, we don't need to // check the NAT table. - if tableID == natID && pkt.NatDone { + if tableID == NATID && pkt.NatDone { continue } var table Table diff --git a/pkg/tcpip/stack/iptables_targets.go b/pkg/tcpip/stack/iptables_targets.go index 538c4625d..d63e9757c 100644 --- a/pkg/tcpip/stack/iptables_targets.go +++ b/pkg/tcpip/stack/iptables_targets.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/log" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -26,13 +28,6 @@ type AcceptTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (at *AcceptTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: at.NetworkProtocol, - } -} - // Action implements Target.Action. func (*AcceptTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleAccept, 0 @@ -44,22 +39,11 @@ type DropTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (dt *DropTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: dt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*DropTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleDrop, 0 } -// ErrorTargetName is used to mark targets as error targets. Error targets -// shouldn't be reached - an error has occurred if we fall through to one. -const ErrorTargetName = "ERROR" - // ErrorTarget logs an error and drops the packet. It represents a target that // should be unreachable. type ErrorTarget struct { @@ -67,14 +51,6 @@ type ErrorTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (et *ErrorTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: et.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ErrorTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { log.Debugf("ErrorTarget triggered.") @@ -90,14 +66,6 @@ type UserChainTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (uc *UserChainTarget) ID() TargetID { - return TargetID{ - Name: ErrorTargetName, - NetworkProtocol: uc.NetworkProtocol, - } -} - // Action implements Target.Action. func (*UserChainTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { panic("UserChainTarget should never be called.") @@ -110,50 +78,39 @@ type ReturnTarget struct { NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *ReturnTarget) ID() TargetID { - return TargetID{ - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. func (*ReturnTarget) Action(*PacketBuffer, *ConnTrack, Hook, *GSO, *Route, tcpip.Address) (RuleVerdict, int) { return RuleReturn, 0 } -// RedirectTargetName is used to mark targets as redirect targets. Redirect -// targets should be reached for only NAT and Mangle tables. These targets will -// change the destination port/destination IP for packets. -const RedirectTargetName = "REDIRECT" - -// RedirectTarget redirects the packet by modifying the destination port/IP. +// RedirectTarget redirects the packet to this machine by modifying the +// destination port/IP. Outgoing packets are redirected to the loopback device, +// and incoming packets are redirected to the incoming interface (rather than +// forwarded). +// // TODO(gvisor.dev/issue/170): Other flags need to be added after we support // them. type RedirectTarget struct { - // Addr indicates address used to redirect. - Addr tcpip.Address - - // Port indicates port used to redirect. + // Port indicates port used to redirect. It is immutable. Port uint16 - // NetworkProtocol is the network protocol the target is used with. + // NetworkProtocol is the network protocol the target is used with. It + // is immutable. NetworkProtocol tcpip.NetworkProtocolNumber } -// ID implements Target.ID. -func (rt *RedirectTarget) ID() TargetID { - return TargetID{ - Name: RedirectTargetName, - NetworkProtocol: rt.NetworkProtocol, - } -} - // Action implements Target.Action. // TODO(gvisor.dev/issue/170): Parse headers without copying. The current -// implementation only works for PREROUTING and calls pkt.Clone(), neither +// implementation only works for Prerouting and calls pkt.Clone(), neither // of which should be the case. func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gso *GSO, r *Route, address tcpip.Address) (RuleVerdict, int) { + // Sanity check. + if rt.NetworkProtocol != pkt.NetworkProtocolNumber { + panic(fmt.Sprintf( + "RedirectTarget.Action with NetworkProtocol %d called on packet with NetworkProtocolNumber %d", + rt.NetworkProtocol, pkt.NetworkProtocolNumber)) + } + // Packet is already manipulated. if pkt.NatDone { return RuleAccept, 0 @@ -164,17 +121,17 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs return RuleDrop, 0 } - // Change the address to localhost (127.0.0.1 or ::1) in Output and to + // Change the address to loopback (127.0.0.1 or ::1) in Output and to // the primary address of the incoming interface in Prerouting. switch hook { case Output: if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { - rt.Addr = tcpip.Address([]byte{127, 0, 0, 1}) + address = tcpip.Address([]byte{127, 0, 0, 1}) } else { - rt.Addr = header.IPv6Loopback + address = header.IPv6Loopback } case Prerouting: - rt.Addr = address + // No-op, as address is already set correctly. default: panic("redirect target is supported only on output and prerouting hooks") } @@ -189,21 +146,18 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Calculate UDP checksum and set it. if hook == Output { udpHeader.SetChecksum(0) + netHeader := pkt.Network() + netHeader.SetDestinationAddress(address) // Only calculate the checksum if offloading isn't supported. - if r.Capabilities()&CapabilityTXChecksumOffload == 0 { + if r.RequiresTXTransportChecksum() { length := uint16(pkt.Size()) - uint16(len(pkt.NetworkHeader().View())) - xsum := r.PseudoHeaderChecksum(protocol, length) - for _, v := range pkt.Data.Views() { - xsum = header.Checksum(v, xsum) - } - udpHeader.SetChecksum(0) + xsum := header.PseudoHeaderChecksum(protocol, netHeader.SourceAddress(), netHeader.DestinationAddress(), length) + xsum = header.ChecksumVV(pkt.Data, xsum) udpHeader.SetChecksum(^udpHeader.CalculateChecksum(xsum)) } } - pkt.Network().SetDestinationAddress(rt.Addr) - // After modification, IPv4 packets need a valid checksum. if pkt.NetworkProtocolNumber == header.IPv4ProtocolNumber { netHeader := header.IPv4(pkt.NetworkHeader().View()) @@ -219,7 +173,7 @@ func (rt *RedirectTarget) Action(pkt *PacketBuffer, ct *ConnTrack, hook Hook, gs // Set up conection for matching NAT rule. Only the first // packet of the connection comes here. Other packets will be // manipulated in connection tracking. - if conn := ct.insertRedirectConn(pkt, hook, rt); conn != nil { + if conn := ct.insertRedirectConn(pkt, hook, rt.Port, address); conn != nil { ct.handlePacket(pkt, hook, gso, r) } default: diff --git a/pkg/tcpip/stack/iptables_types.go b/pkg/tcpip/stack/iptables_types.go index 7b3f3e88b..4b86c1be9 100644 --- a/pkg/tcpip/stack/iptables_types.go +++ b/pkg/tcpip/stack/iptables_types.go @@ -37,7 +37,6 @@ import ( // ----->[Prerouting]----->routing----->[Forward]---------[Postrouting]-----> type Hook uint -// These values correspond to values in include/uapi/linux/netfilter.h. const ( // Prerouting happens before a packet is routed to applications or to // be forwarded. @@ -86,8 +85,8 @@ type IPTables struct { mu sync.RWMutex // v4Tables and v6tables map tableIDs to tables. They hold builtin // tables only, not user tables. mu must be locked for accessing. - v4Tables [numTables]Table - v6Tables [numTables]Table + v4Tables [NumTables]Table + v6Tables [NumTables]Table // modified is whether tables have been modified at least once. It is // used to elide the iptables performance overhead for workloads that // don't utilize iptables. @@ -96,7 +95,7 @@ type IPTables struct { // priorities maps each hook to a list of table names. The order of the // list is the order in which each table should be visited for that // hook. It is immutable. - priorities [NumHooks][]tableID + priorities [NumHooks][]TableID connections ConnTrack @@ -104,6 +103,24 @@ type IPTables struct { reaperDone chan struct{} } +// VisitTargets traverses all the targets of all tables and replaces each with +// transform(target). +func (it *IPTables) VisitTargets(transform func(Target) Target) { + it.mu.Lock() + defer it.mu.Unlock() + + for tid := range it.v4Tables { + for i, rule := range it.v4Tables[tid].Rules { + it.v4Tables[tid].Rules[i].Target = transform(rule.Target) + } + } + for tid := range it.v6Tables { + for i, rule := range it.v6Tables[tid].Rules { + it.v6Tables[tid].Rules[i].Target = transform(rule.Target) + } + } +} + // A Table defines a set of chains and hooks into the network stack. // // It is a list of Rules, entry points (BuiltinChains), and error handlers @@ -169,7 +186,6 @@ type IPHeaderFilter struct { // CheckProtocol determines whether the Protocol field should be // checked during matching. - // TODO(gvisor.dev/issue/3549): Check this field during matching. CheckProtocol bool // Dst matches the destination IP address. @@ -309,23 +325,8 @@ type Matcher interface { Match(hook Hook, packet *PacketBuffer, interfaceName string) (matches bool, hotdrop bool) } -// A TargetID uniquely identifies a target. -type TargetID struct { - // Name is the target name as stored in the xt_entry_target struct. - Name string - - // NetworkProtocol is the protocol to which the target applies. - NetworkProtocol tcpip.NetworkProtocolNumber - - // Revision is the version of the target. - Revision uint8 -} - // A Target is the interface for taking an action for a packet. type Target interface { - // ID uniquely identifies the Target. - ID() TargetID - // Action takes an action on the packet and returns a verdict on how // traversal should (or should not) continue. If the return value is // Jump, it also returns the index of the rule to jump to. diff --git a/pkg/tcpip/stack/linkaddrcache.go b/pkg/tcpip/stack/linkaddrcache.go index 6f73a0ce4..c9b13cd0e 100644 --- a/pkg/tcpip/stack/linkaddrcache.go +++ b/pkg/tcpip/stack/linkaddrcache.go @@ -180,7 +180,7 @@ func (c *linkAddrCache) getOrCreateEntryLocked(k tcpip.FullAddress) *linkAddrEnt } // get reports any known link address for k. -func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { +func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, waker *sleep.Waker) (tcpip.LinkAddress, <-chan struct{}, *tcpip.Error) { if linkRes != nil { if addr, ok := linkRes.ResolveStaticAddress(k.Addr); ok { return addr, nil, nil @@ -221,7 +221,7 @@ func (c *linkAddrCache) get(k tcpip.FullAddress, linkRes LinkAddressResolver, lo } entry.done = make(chan struct{}) - go c.startAddressResolution(k, linkRes, localAddr, linkEP, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. + go c.startAddressResolution(k, linkRes, localAddr, nic, entry.done) // S/R-SAFE: link non-savable; wakers dropped synchronously. } return entry.linkAddr, entry.done, tcpip.ErrWouldBlock @@ -240,11 +240,11 @@ func (c *linkAddrCache) removeWaker(k tcpip.FullAddress, waker *sleep.Waker) { } } -func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, linkEP LinkEndpoint, done <-chan struct{}) { +func (c *linkAddrCache) startAddressResolution(k tcpip.FullAddress, linkRes LinkAddressResolver, localAddr tcpip.Address, nic NetworkInterface, done <-chan struct{}) { for i := 0; ; i++ { // Send link request, then wait for the timeout limit and check // whether the request succeeded. - linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, linkEP) + linkRes.LinkAddressRequest(k.Addr, localAddr, "" /* linkAddr */, nic) select { case now := <-time.After(c.resolutionTimeout): diff --git a/pkg/tcpip/stack/linkaddrcache_test.go b/pkg/tcpip/stack/linkaddrcache_test.go index 33806340e..d2e37f38d 100644 --- a/pkg/tcpip/stack/linkaddrcache_test.go +++ b/pkg/tcpip/stack/linkaddrcache_test.go @@ -49,8 +49,8 @@ type testLinkAddressResolver struct { onLinkAddressRequest func() } -func (r *testLinkAddressResolver) LinkAddressRequest(addr, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { - time.AfterFunc(r.delay, func() { r.fakeRequest(addr) }) +func (r *testLinkAddressResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { + time.AfterFunc(r.delay, func() { r.fakeRequest(targetAddr) }) if f := r.onLinkAddressRequest; f != nil { f() } diff --git a/pkg/tcpip/stack/neighbor_cache.go b/pkg/tcpip/stack/neighbor_cache.go index 4df288798..177bf5516 100644 --- a/pkg/tcpip/stack/neighbor_cache.go +++ b/pkg/tcpip/stack/neighbor_cache.go @@ -16,7 +16,6 @@ package stack import ( "fmt" - "time" "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/sync" @@ -68,7 +67,7 @@ var _ NUDHandler = (*neighborCache)(nil) // reset to state incomplete, and returned. If no matching entry exists and the // cache is not full, a new entry with state incomplete is allocated and // returned. -func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { +func (n *neighborCache) getOrCreateEntry(remoteAddr tcpip.Address, linkRes LinkAddressResolver) *neighborEntry { n.mu.Lock() defer n.mu.Unlock() @@ -84,7 +83,7 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // The entry that needs to be created must be dynamic since all static // entries are directly added to the cache via addStaticEntry. - entry := newNeighborEntry(n.nic, remoteAddr, localAddr, n.state, linkRes) + entry := newNeighborEntry(n.nic, remoteAddr, n.state, linkRes) if n.dynamic.count == neighborCacheSize { e := n.dynamic.lru.Back() e.mu.Lock() @@ -111,28 +110,31 @@ func (n *neighborCache) getOrCreateEntry(remoteAddr, localAddr tcpip.Address, li // provided, it will be notified when address resolution is complete (success // or not). // +// If specified, the local address must be an address local to the interface the +// neighbor cache belongs to. The local address is the source address of a +// packet prompting NUD/link address resolution. +// // If address resolution is required, ErrNoLinkAddress and a notification // channel is returned for the top level caller to block. Channel is closed // once address resolution is complete (success or not). func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkAddressResolver, w *sleep.Waker) (NeighborEntry, <-chan struct{}, *tcpip.Error) { if linkAddr, ok := linkRes.ResolveStaticAddress(remoteAddr); ok { e := NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), + Addr: remoteAddr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: 0, } return e, nil, nil } - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() defer entry.mu.Unlock() switch s := entry.neigh.State; s { case Stale: - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) fallthrough case Reachable, Static, Delay, Probe: // As per RFC 4861 section 7.3.3: @@ -152,7 +154,7 @@ func (n *neighborCache) entry(remoteAddr, localAddr tcpip.Address, linkRes LinkA entry.done = make(chan struct{}) } - entry.handlePacketQueuedLocked() + entry.handlePacketQueuedLocked(localAddr) return entry.neigh, entry.done, tcpip.ErrWouldBlock case Failed: return entry.neigh, nil, tcpip.ErrNoLinkAddress @@ -207,7 +209,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd } else { // Static entry found with the same address but different link address. entry.neigh.LinkAddr = linkAddr - entry.dispatchChangeEventLocked(entry.neigh.State) + entry.dispatchChangeEventLocked() entry.mu.Unlock() return } @@ -220,8 +222,7 @@ func (n *neighborCache) addStaticEntry(addr tcpip.Address, linkAddr tcpip.LinkAd entry.mu.Unlock() } - entry := newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) - n.cache[addr] = entry + n.cache[addr] = newStaticNeighborEntry(n.nic, addr, linkAddr, n.state) } // removeEntryLocked removes the specified entry from the neighbor cache. @@ -292,8 +293,8 @@ func (n *neighborCache) setConfig(config NUDConfigurations) { // HandleProbe implements NUDHandler.HandleProbe by following the logic defined // in RFC 4861 section 7.2.3. Validation of the probe is expected to be handled // by the caller. -func (n *neighborCache) HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { - entry := n.getOrCreateEntry(remoteAddr, localAddr, linkRes) +func (n *neighborCache) HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) { + entry := n.getOrCreateEntry(remoteAddr, linkRes) entry.mu.Lock() entry.handleProbeLocked(remoteLinkAddr) entry.mu.Unlock() diff --git a/pkg/tcpip/stack/neighbor_cache_test.go b/pkg/tcpip/stack/neighbor_cache_test.go index fcd54ed83..ed33418f3 100644 --- a/pkg/tcpip/stack/neighbor_cache_test.go +++ b/pkg/tcpip/stack/neighbor_cache_test.go @@ -61,23 +61,20 @@ const ( ) // entryDiffOpts returns the options passed to cmp.Diff to compare neighbor -// entries. The UpdatedAt field is ignored due to a lack of a deterministic -// method to predict the time that an event will be dispatched. +// entries. The UpdatedAtNanos field is ignored due to a lack of a +// deterministic method to predict the time that an event will be dispatched. func entryDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // entryDiffOptsWithSort is like entryDiffOpts but also includes an option to // sort slices of entries for cases where ordering must be ignored. func entryDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b NeighborEntry) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(entryDiffOpts(), cmpopts.SortSlices(func(a, b NeighborEntry) bool { + return strings.Compare(string(a.Addr), string(b.Addr)) < 0 + })) } func newTestNeighborCache(nudDisp NUDDispatcher, config NUDConfigurations, clock tcpip.Clock) *neighborCache { @@ -128,9 +125,8 @@ func newTestEntryStore() *testEntryStore { linkAddr := toLinkAddress(i) store.entriesMap[addr] = NeighborEntry{ - Addr: addr, - LocalAddr: testEntryLocalAddr, - LinkAddr: linkAddr, + Addr: addr, + LinkAddr: linkAddr, } } return store @@ -195,10 +191,10 @@ type testNeighborResolver struct { var _ LinkAddressResolver = (*testNeighborResolver)(nil) -func (r *testNeighborResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *testNeighborResolver) LinkAddressRequest(targetAddr, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { // Delay handling the request to emulate network latency. r.clock.AfterFunc(r.delay, func() { - r.fakeRequest(addr) + r.fakeRequest(targetAddr) }) // Execute post address resolution action, if available. @@ -294,9 +290,8 @@ func TestNeighborCacheEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -305,15 +300,19 @@ func TestNeighborCacheEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -324,8 +323,8 @@ func TestNeighborCacheEntry(t *testing.T) { t.Fatalf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil { + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } // No more events should have been dispatched. @@ -354,9 +353,9 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -365,15 +364,19 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -391,9 +394,11 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -404,8 +409,8 @@ func TestNeighborCacheRemoveEntry(t *testing.T) { } } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } } @@ -452,8 +457,8 @@ func (c *testContext) overflowCache(opts overflowOptions) error { if !ok { return fmt.Errorf("c.store.entry(%d) not found", i) } - if _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil); err != tcpip.ErrWouldBlock { - return fmt.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + return fmt.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) @@ -470,23 +475,29 @@ func (c *testContext) overflowCache(opts overflowOptions) error { wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }) } wantEvents = append(wantEvents, testEntryEventInfo{ EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, testEntryEventInfo{ EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }) c.nudDisp.mu.Lock() @@ -508,10 +519,9 @@ func (c *testContext) overflowCache(opts overflowOptions) error { return fmt.Errorf("c.store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -564,24 +574,27 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(c.neigh.config().RetransmitTimer) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -600,9 +613,11 @@ func TestNeighborCacheRemoveEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -640,9 +655,11 @@ func TestNeighborCacheDuplicateStaticEntryWithSameLinkAddress(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -682,9 +699,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -703,9 +722,11 @@ func TestNeighborCacheDuplicateStaticEntryWithDifferentLinkAddress(t *testing.T) { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -740,9 +761,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -760,9 +783,11 @@ func TestNeighborCacheRemoveStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -800,24 +825,27 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -836,16 +864,20 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: staticLinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -861,10 +893,9 @@ func TestNeighborCacheOverwriteWithStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: staticLinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: staticLinkAddr, + State: Static, }, }, } @@ -896,12 +927,12 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _ = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _ = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } clock.Advance(typicalLatency) @@ -913,7 +944,7 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { id, ok := s.Fetch(false /* block */) if !ok { - t.Errorf("expected waker to be notified after neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Errorf("expected waker to be notified after neigh.entry(%s, '', _, _)", entry.Addr) } if id != wakerID { t.Errorf("got s.Fetch(false) = %d, want = %d", id, wakerID) @@ -923,15 +954,19 @@ func TestNeighborCacheNotifiesWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -964,12 +999,12 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, &w) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, &w) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, _) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, _) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh == nil { - t.Fatalf("expected done channel from neigh.entry(%s, %s, _, _)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected done channel from neigh.entry(%s, '', _, _)", entry.Addr) } // Remove the waker before the neighbor cache has the opportunity to send a @@ -991,15 +1026,19 @@ func TestNeighborCacheRemoveWaker(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1028,10 +1067,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { t.Errorf("unexpected error from c.neigh.entry(%s, \"\", _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { t.Errorf("c.neigh.entry(%s, \"\", _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) @@ -1041,9 +1079,11 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Static, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, + }, }, } c.nudDisp.mu.Lock() @@ -1058,10 +1098,9 @@ func TestNeighborCacheAddStaticEntryThenOverflow(t *testing.T) { startAtEntryIndex: 1, wantStaticEntries: []NeighborEntry{ { - Addr: entry.Addr, - LocalAddr: "", // static entries don't need a local address - LinkAddr: entry.LinkAddr, - State: Static, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Static, }, }, } @@ -1089,9 +1128,8 @@ func TestNeighborCacheClear(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) @@ -1099,15 +1137,19 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1126,9 +1168,11 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestAdded, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1149,16 +1193,20 @@ func TestNeighborCacheClear(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestRemoved, NICID: 1, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Static, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Static, + }, }, } nudDisp.mu.Lock() @@ -1185,24 +1233,27 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { if !ok { t.Fatalf("c.store.entry(0) not found") } - _, _, err := c.neigh.entry(entry.Addr, entry.LocalAddr, c.linkRes, nil) - if err != tcpip.ErrWouldBlock { - t.Errorf("got c.neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := c.neigh.entry(entry.Addr, "", c.linkRes, nil); err != tcpip.ErrWouldBlock { + t.Errorf("got c.neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } c.clock.Advance(typicalLatency) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1220,9 +1271,11 @@ func TestNeighborCacheClearThenOverflow(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } c.nudDisp.mu.Lock() @@ -1274,29 +1327,33 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1312,9 +1369,8 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { for i := neighborCacheSize; i < store.size(); i++ { // Periodically refresh the frequently used entry if i%(neighborCacheSize/2) == 0 { - _, _, err := neigh.entry(frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, linkRes, nil) - if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", frequentlyUsedEntry.Addr, frequentlyUsedEntry.LocalAddr, err) + if _, _, err := neigh.entry(frequentlyUsedEntry.Addr, "", linkRes, nil); err != nil { + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", frequentlyUsedEntry.Addr, err) } } @@ -1322,15 +1378,15 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { if !ok { t.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Errorf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // An entry should have been removed, as per the LRU eviction strategy @@ -1342,22 +1398,28 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { { EventType: entryTestRemoved, NICID: 1, - Addr: removedEntry.Addr, - LinkAddr: removedEntry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: removedEntry.Addr, + LinkAddr: removedEntry.LinkAddr, + State: Reachable, + }, }, { EventType: entryTestAdded, NICID: 1, - Addr: entry.Addr, - State: Incomplete, + Entry: NeighborEntry{ + Addr: entry.Addr, + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: 1, - Addr: entry.Addr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Entry: NeighborEntry{ + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1374,10 +1436,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { // have to be sorted before comparison. wantUnsortedEntries := []NeighborEntry{ { - Addr: frequentlyUsedEntry.Addr, - LocalAddr: frequentlyUsedEntry.LocalAddr, - LinkAddr: frequentlyUsedEntry.LinkAddr, - State: Reachable, + Addr: frequentlyUsedEntry.Addr, + LinkAddr: frequentlyUsedEntry.LinkAddr, + State: Reachable, }, } @@ -1387,10 +1448,9 @@ func TestNeighborCacheKeepFrequentlyUsed(t *testing.T) { t.Fatalf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1430,9 +1490,8 @@ func TestNeighborCacheConcurrent(t *testing.T) { wg.Add(1) go func(entry NeighborEntry) { defer wg.Done() - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) - if err != nil && err != tcpip.ErrWouldBlock { - t.Errorf("got neigh.entry(%s, %s, _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, entry.LocalAddr, e, err, tcpip.ErrWouldBlock) + if e, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != nil && err != tcpip.ErrWouldBlock { + t.Errorf("got neigh.entry(%s, '', _, nil) = (%+v, _, %s), want (_, _, nil) or (_, _, %s)", entry.Addr, e, err, tcpip.ErrWouldBlock) } }(entry) } @@ -1456,10 +1515,9 @@ func TestNeighborCacheConcurrent(t *testing.T) { t.Errorf("store.entry(%d) not found", i) } wantEntry := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } wantUnsortedEntries = append(wantUnsortedEntries, wantEntry) } @@ -1488,37 +1546,36 @@ func TestNeighborCacheReplace(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) select { case <-doneCh: default: - t.Fatalf("expected notification from done channel returned by neigh.entry(%s, %s, _, nil)", entry.Addr, entry.LocalAddr) + t.Fatalf("expected notification from done channel returned by neigh.entry(%s, '', _, nil)", entry.Addr) } // Verify the entry exists { - e, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } if doneCh != nil { - t.Errorf("unexpected done channel from neigh.entry(%s, %s, _, nil): %v", entry.Addr, entry.LocalAddr, doneCh) + t.Errorf("unexpected done channel from neigh.entry(%s, '', _, nil): %v", entry.Addr, doneCh) } if t.Failed() { t.FailNow() } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LinkAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } @@ -1542,37 +1599,35 @@ func TestNeighborCacheReplace(t *testing.T) { // // Verify the entry's new link address and the new state. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Delay, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Delay, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } clock.Advance(config.DelayFirstProbeTime + typicalLatency) } // Verify that the neighbor is now reachable. { - e, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + e, _, err := neigh.entry(entry.Addr, "", linkRes, nil) clock.Advance(typicalLatency) if err != nil { - t.Errorf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Errorf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: updatedLinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: updatedLinkAddr, + State: Reachable, } if diff := cmp.Diff(e, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } } } @@ -1601,35 +1656,34 @@ func TestNeighborCacheResolutionFailed(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } clock.Advance(typicalLatency) - got, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + got, _, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", entry.Addr, entry.LocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", entry.Addr, err) } want := NeighborEntry{ - Addr: entry.Addr, - LocalAddr: entry.LocalAddr, - LinkAddr: entry.LinkAddr, - State: Reachable, + Addr: entry.Addr, + LinkAddr: entry.LinkAddr, + State: Reachable, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", entry.Addr, entry.LocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", entry.Addr, diff) } // Verify that address resolution for an unknown address returns ErrNoLinkAddress before := atomic.LoadUint32(&requestCount) entry.Addr += "2" - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.DelayFirstProbeTime + typicalLatency*time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } maxAttempts := neigh.config().MaxUnicastProbes @@ -1659,13 +1713,13 @@ func TestNeighborCacheResolutionTimeout(t *testing.T) { if !ok { t.Fatalf("store.entry(0) not found") } - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrWouldBlock { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrWouldBlock { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } waitFor := config.RetransmitTimer * time.Duration(config.MaxMulticastProbes) clock.Advance(waitFor) - if _, _, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil); err != tcpip.ErrNoLinkAddress { - t.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrNoLinkAddress) + if _, _, err := neigh.entry(entry.Addr, "", linkRes, nil); err != tcpip.ErrNoLinkAddress { + t.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrNoLinkAddress) } } @@ -1683,18 +1737,17 @@ func TestNeighborCacheStaticResolution(t *testing.T) { delay: typicalLatency, } - got, _, err := neigh.entry(testEntryBroadcastAddr, testEntryLocalAddr, linkRes, nil) + got, _, err := neigh.entry(testEntryBroadcastAddr, "", linkRes, nil) if err != nil { - t.Fatalf("unexpected error from neigh.entry(%s, %s, _, nil): %s", testEntryBroadcastAddr, testEntryLocalAddr, err) + t.Fatalf("unexpected error from neigh.entry(%s, '', _, nil): %s", testEntryBroadcastAddr, err) } want := NeighborEntry{ - Addr: testEntryBroadcastAddr, - LocalAddr: testEntryLocalAddr, - LinkAddr: testEntryBroadcastLinkAddr, - State: Static, + Addr: testEntryBroadcastAddr, + LinkAddr: testEntryBroadcastLinkAddr, + State: Static, } if diff := cmp.Diff(got, want, entryDiffOpts()...); diff != "" { - t.Errorf("neigh.entry(%s, %s, _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, testEntryLocalAddr, diff) + t.Errorf("neigh.entry(%s, '', _, nil) mismatch (-got, +want):\n%s", testEntryBroadcastAddr, diff) } } @@ -1719,9 +1772,9 @@ func BenchmarkCacheClear(b *testing.B) { if !ok { b.Fatalf("store.entry(%d) not found", i) } - _, doneCh, err := neigh.entry(entry.Addr, entry.LocalAddr, linkRes, nil) + _, doneCh, err := neigh.entry(entry.Addr, "", linkRes, nil) if err != tcpip.ErrWouldBlock { - b.Fatalf("got neigh.entry(%s, %s, _, nil) = %v, want = %s", entry.Addr, entry.LocalAddr, err, tcpip.ErrWouldBlock) + b.Fatalf("got neigh.entry(%s, '', _, nil) = %v, want = %s", entry.Addr, err, tcpip.ErrWouldBlock) } if doneCh != nil { <-doneCh diff --git a/pkg/tcpip/stack/neighbor_entry.go b/pkg/tcpip/stack/neighbor_entry.go index be61a21af..493e48031 100644 --- a/pkg/tcpip/stack/neighbor_entry.go +++ b/pkg/tcpip/stack/neighbor_entry.go @@ -24,13 +24,18 @@ import ( "gvisor.dev/gvisor/pkg/tcpip/header" ) +const ( + // immediateDuration is a duration of zero for scheduling work that needs to + // be done immediately but asynchronously to avoid deadlock. + immediateDuration time.Duration = 0 +) + // NeighborEntry describes a neighboring device in the local network. type NeighborEntry struct { - Addr tcpip.Address - LocalAddr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Addr tcpip.Address + LinkAddr tcpip.LinkAddress + State NeighborState + UpdatedAtNanos int64 } // NeighborState defines the state of a NeighborEntry within the Neighbor @@ -106,35 +111,35 @@ type neighborEntry struct { // state, Unknown. Transition out of Unknown by calling either // `handlePacketQueuedLocked` or `handleProbeLocked` on the newly created // neighborEntry. -func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, localAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { +func newNeighborEntry(nic *NIC, remoteAddr tcpip.Address, nudState *NUDState, linkRes LinkAddressResolver) *neighborEntry { return &neighborEntry{ nic: nic, linkRes: linkRes, nudState: nudState, neigh: NeighborEntry{ - Addr: remoteAddr, - LocalAddr: localAddr, - State: Unknown, + Addr: remoteAddr, + State: Unknown, }, } } -// newStaticNeighborEntry creates a neighbor cache entry starting at the Static -// state. The entry can only transition out of Static by directly calling -// `setStateLocked`. +// newStaticNeighborEntry creates a neighbor cache entry starting at the +// Static state. The entry can only transition out of Static by directly +// calling `setStateLocked`. func newStaticNeighborEntry(nic *NIC, addr tcpip.Address, linkAddr tcpip.LinkAddress, state *NUDState) *neighborEntry { + entry := NeighborEntry{ + Addr: addr, + LinkAddr: linkAddr, + State: Static, + UpdatedAtNanos: nic.stack.clock.NowNanoseconds(), + } if nic.stack.nudDisp != nil { - nic.stack.nudDisp.OnNeighborAdded(nic.id, addr, linkAddr, Static, time.Now()) + nic.stack.nudDisp.OnNeighborAdded(nic.id, entry) } return &neighborEntry{ nic: nic, nudState: state, - neigh: NeighborEntry{ - Addr: addr, - LinkAddr: linkAddr, - State: Static, - UpdatedAt: time.Now(), - }, + neigh: entry, } } @@ -165,17 +170,17 @@ func (e *neighborEntry) notifyWakersLocked() { // dispatchAddEventLocked signals to stack's NUD Dispatcher that the entry has // been added. -func (e *neighborEntry) dispatchAddEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchAddEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborAdded(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborAdded(e.nic.id, e.neigh) } } // dispatchChangeEventLocked signals to stack's NUD Dispatcher that the entry // has changed state or link-layer address. -func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { +func (e *neighborEntry) dispatchChangeEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborChanged(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, nextState, time.Now()) + nudDisp.OnNeighborChanged(e.nic.id, e.neigh) } } @@ -183,7 +188,7 @@ func (e *neighborEntry) dispatchChangeEventLocked(nextState NeighborState) { // has been removed. func (e *neighborEntry) dispatchRemoveEventLocked() { if nudDisp := e.nic.stack.nudDisp; nudDisp != nil { - nudDisp.OnNeighborRemoved(e.nic.id, e.neigh.Addr, e.neigh.LinkAddr, e.neigh.State, time.Now()) + nudDisp.OnNeighborRemoved(e.nic.id, e.neigh) } } @@ -201,68 +206,24 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { prev := e.neigh.State e.neigh.State = next - e.neigh.UpdatedAt = time.Now() + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() config := e.nudState.Config() switch next { case Incomplete: - var retryCounter uint32 - var sendMulticastProbe func() - - sendMulticastProbe = func() { - if retryCounter == config.MaxMulticastProbes { - // "If no Neighbor Advertisement is received after - // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. - // The sender MUST return ICMP destination unreachable indications with - // code 3 (Address Unreachable) for each packet queued awaiting address - // resolution." - RFC 4861 section 7.2.2 - // - // There is no need to send an ICMP destination unreachable indication - // since the failure to resolve the address is expected to only occur - // on this node. Thus, redirecting traffic is currently not supported. - // - // "If the error occurs on a node other than the node originating the - // packet, an ICMP error message is generated. If the error occurs on - // the originating node, an implementation is not required to actually - // create and send an ICMP error packet to the source, as long as the - // upper-layer sender is notified through an appropriate mechanism - // (e.g. return value from a procedure call). Note, however, that an - // implementation may find it convenient in some cases to return errors - // to the sender by taking the offending packet, generating an ICMP - // error message, and then delivering it (locally) through the generic - // error-handling routines.' - RFC 4861 section 2.1 - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, "", e.nic.LinkEndpoint); err != nil { - // There is no need to log the error here; the NUD implementation may - // assume a working link. A valid link should be the responsibility of - // the NIC/stack.LinkEndpoint. - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - - retryCounter++ - e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) - e.job.Schedule(config.RetransmitTimer) - } - - sendMulticastProbe() + panic(fmt.Sprintf("should never transition to Incomplete with setStateLocked; neigh = %#v, prev state = %s", e.neigh, prev)) case Reachable: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() }) e.job.Schedule(e.nudState.ReachableTime()) case Delay: e.job = e.nic.stack.newJob(&e.mu, func() { - e.dispatchChangeEventLocked(Probe) e.setStateLocked(Probe) + e.dispatchChangeEventLocked() }) e.job.Schedule(config.DelayFirstProbeTime) @@ -277,24 +238,23 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { return } - if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, e.neigh.LocalAddr, e.neigh.LinkAddr, e.nic.LinkEndpoint); err != nil { + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, "" /* localAddr */, e.neigh.LinkAddr, e.nic); err != nil { e.dispatchRemoveEventLocked() e.setStateLocked(Failed) return } retryCounter++ - if retryCounter == config.MaxUnicastProbes { - e.dispatchRemoveEventLocked() - e.setStateLocked(Failed) - return - } - e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) e.job.Schedule(config.RetransmitTimer) } - sendUnicastProbe() + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendUnicastProbe) + e.job.Schedule(immediateDuration) case Failed: e.notifyWakersLocked() @@ -315,15 +275,77 @@ func (e *neighborEntry) setStateLocked(next NeighborState) { // being queued for outgoing transmission. // // Follows the logic defined in RFC 4861 section 7.3.3. -func (e *neighborEntry) handlePacketQueuedLocked() { +func (e *neighborEntry) handlePacketQueuedLocked(localAddr tcpip.Address) { switch e.neigh.State { case Unknown: - e.dispatchAddEventLocked(Incomplete) - e.setStateLocked(Incomplete) + e.neigh.State = Incomplete + e.neigh.UpdatedAtNanos = e.nic.stack.clock.NowNanoseconds() + + e.dispatchAddEventLocked() + + config := e.nudState.Config() + + var retryCounter uint32 + var sendMulticastProbe func() + + sendMulticastProbe = func() { + if retryCounter == config.MaxMulticastProbes { + // "If no Neighbor Advertisement is received after + // MAX_MULTICAST_SOLICIT solicitations, address resolution has failed. + // The sender MUST return ICMP destination unreachable indications with + // code 3 (Address Unreachable) for each packet queued awaiting address + // resolution." - RFC 4861 section 7.2.2 + // + // There is no need to send an ICMP destination unreachable indication + // since the failure to resolve the address is expected to only occur + // on this node. Thus, redirecting traffic is currently not supported. + // + // "If the error occurs on a node other than the node originating the + // packet, an ICMP error message is generated. If the error occurs on + // the originating node, an implementation is not required to actually + // create and send an ICMP error packet to the source, as long as the + // upper-layer sender is notified through an appropriate mechanism + // (e.g. return value from a procedure call). Note, however, that an + // implementation may find it convenient in some cases to return errors + // to the sender by taking the offending packet, generating an ICMP + // error message, and then delivering it (locally) through the generic + // error-handling routines.' - RFC 4861 section 2.1 + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + // As per RFC 4861 section 7.2.2: + // + // If the source address of the packet prompting the solicitation is the + // same as one of the addresses assigned to the outgoing interface, that + // address SHOULD be placed in the IP Source Address of the outgoing + // solicitation. + // + if err := e.linkRes.LinkAddressRequest(e.neigh.Addr, localAddr, "", e.nic); err != nil { + // There is no need to log the error here; the NUD implementation may + // assume a working link. A valid link should be the responsibility of + // the NIC/stack.LinkEndpoint. + e.dispatchRemoveEventLocked() + e.setStateLocked(Failed) + return + } + + retryCounter++ + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(config.RetransmitTimer) + } + + // Send a probe in another gorountine to free this thread of execution + // for finishing the state transition. This is necessary to avoid + // deadlock where sending and processing probes are done synchronously, + // such as loopback and integration tests. + e.job = e.nic.stack.newJob(&e.mu, sendMulticastProbe) + e.job.Schedule(immediateDuration) case Stale: - e.dispatchChangeEventLocked(Delay) e.setStateLocked(Delay) + e.dispatchChangeEventLocked() case Incomplete, Reachable, Delay, Probe, Static, Failed: // Do nothing @@ -345,21 +367,21 @@ func (e *neighborEntry) handleProbeLocked(remoteLinkAddr tcpip.LinkAddress) { switch e.neigh.State { case Unknown, Incomplete, Failed: e.neigh.LinkAddr = remoteLinkAddr - e.dispatchAddEventLocked(Stale) e.setStateLocked(Stale) e.notifyWakersLocked() + e.dispatchAddEventLocked() case Reachable, Delay, Probe: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } case Stale: if e.neigh.LinkAddr != remoteLinkAddr { e.neigh.LinkAddr = remoteLinkAddr - e.dispatchChangeEventLocked(Stale) + e.dispatchChangeEventLocked() } case Static: @@ -393,12 +415,11 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla e.neigh.LinkAddr = linkAddr if flags.Solicited { - e.dispatchChangeEventLocked(Reachable) e.setStateLocked(Reachable) } else { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) } + e.dispatchChangeEventLocked() e.isRouter = flags.IsRouter e.notifyWakersLocked() @@ -411,8 +432,8 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if isLinkAddrDifferent { if !flags.Override { if e.neigh.State == Reachable { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } break } @@ -421,23 +442,24 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla if !flags.Solicited { if e.neigh.State != Stale { - e.dispatchChangeEventLocked(Stale) e.setStateLocked(Stale) + e.dispatchChangeEventLocked() } else { // Notify the LinkAddr change, even though NUD state hasn't changed. - e.dispatchChangeEventLocked(e.neigh.State) + e.dispatchChangeEventLocked() } break } } if flags.Solicited && (flags.Override || !isLinkAddrDifferent) { - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - } + wasReachable := e.neigh.State == Reachable // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) e.notifyWakersLocked() + if !wasReachable { + e.dispatchChangeEventLocked() + } } if e.isRouter && !flags.IsRouter && header.IsV6UnicastAddress(e.neigh.Addr) { @@ -475,11 +497,12 @@ func (e *neighborEntry) handleConfirmationLocked(linkAddr tcpip.LinkAddress, fla func (e *neighborEntry) handleUpperLevelConfirmationLocked() { switch e.neigh.State { case Reachable, Stale, Delay, Probe: - if e.neigh.State != Reachable { - e.dispatchChangeEventLocked(Reachable) - // Set state to Reachable again to refresh timers. - } + wasReachable := e.neigh.State == Reachable + // Set state to Reachable again to refresh timers. e.setStateLocked(Reachable) + if !wasReachable { + e.dispatchChangeEventLocked() + } case Unknown, Incomplete, Failed, Static: // Do nothing diff --git a/pkg/tcpip/stack/neighbor_entry_test.go b/pkg/tcpip/stack/neighbor_entry_test.go index 3ee2a3b31..c2b763325 100644 --- a/pkg/tcpip/stack/neighbor_entry_test.go +++ b/pkg/tcpip/stack/neighbor_entry_test.go @@ -47,24 +47,27 @@ const ( entryTestNetDefaultMTU = 65536 ) +// runImmediatelyScheduledJobs runs all jobs scheduled to run at the current +// time. +func runImmediatelyScheduledJobs(clock *faketime.ManualClock) { + clock.Advance(immediateDuration) +} + // eventDiffOpts are the options passed to cmp.Diff to compare entry events. -// The UpdatedAt field is ignored due to a lack of a deterministic method to -// predict the time that an event will be dispatched. +// The UpdatedAtNanos field is ignored due to a lack of a deterministic method +// to predict the time that an event will be dispatched. func eventDiffOpts() []cmp.Option { return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), + cmpopts.IgnoreFields(NeighborEntry{}, "UpdatedAtNanos"), } } // eventDiffOptsWithSort is like eventDiffOpts but also includes an option to // sort slices of events for cases where ordering must be ignored. func eventDiffOptsWithSort() []cmp.Option { - return []cmp.Option{ - cmpopts.IgnoreFields(testEntryEventInfo{}, "UpdatedAt"), - cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { - return strings.Compare(string(a.Addr), string(b.Addr)) < 0 - }), - } + return append(eventDiffOpts(), cmpopts.SortSlices(func(a, b testEntryEventInfo) bool { + return strings.Compare(string(a.Entry.Addr), string(b.Entry.Addr)) < 0 + })) } // The following unit tests exercise every state transition and verify its @@ -125,14 +128,11 @@ func (t testEntryEventType) String() string { type testEntryEventInfo struct { EventType testEntryEventType NICID tcpip.NICID - Addr tcpip.Address - LinkAddr tcpip.LinkAddress - State NeighborState - UpdatedAt time.Time + Entry NeighborEntry } func (e testEntryEventInfo) String() string { - return fmt.Sprintf("%s event for NIC #%d, addr=%q, linkAddr=%q, state=%q", e.EventType, e.NICID, e.Addr, e.LinkAddr, e.State) + return fmt.Sprintf("%s event for NIC #%d, %#v", e.EventType, e.NICID, e.Entry) } // testNUDDispatcher implements NUDDispatcher to validate the dispatching of @@ -150,36 +150,27 @@ func (d *testNUDDispatcher) queueEvent(e testEntryEventInfo) { d.events = append(d.events, e) } -func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborAdded(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestAdded, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborChanged(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestChanged, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } -func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, addr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) { +func (d *testNUDDispatcher) OnNeighborRemoved(nicID tcpip.NICID, entry NeighborEntry) { d.queueEvent(testEntryEventInfo{ EventType: entryTestRemoved, NICID: nicID, - Addr: addr, - LinkAddr: linkAddr, - State: state, - UpdatedAt: updatedAt, + Entry: entry, }) } @@ -202,9 +193,9 @@ func (p entryTestProbeInfo) String() string { // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts // to the local network if linkAddr is the zero value. -func (r *entryTestLinkResolver) LinkAddressRequest(addr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error { +func (r *entryTestLinkResolver) LinkAddressRequest(targetAddr, localAddr tcpip.Address, linkAddr tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { p := entryTestProbeInfo{ - RemoteAddress: addr, + RemoteAddress: targetAddr, RemoteLinkAddress: linkAddr, LocalAddress: localAddr, } @@ -245,7 +236,7 @@ func entryTestSetup(c NUDConfigurations) (*neighborEntry, *testNUDDispatcher, *e rng := rand.New(rand.NewSource(time.Now().UnixNano())) nudState := NewNUDState(c, rng) linkRes := entryTestLinkResolver{} - entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, entryTestAddr2 /* localAddr */, nudState, &linkRes) + entry := newNeighborEntry(&nic, entryTestAddr1 /* remoteAddr */, nudState, &linkRes) // Stub out the neighbor cache to verify deletion from the cache. nic.neigh = &neighborCache{ @@ -323,15 +314,16 @@ func TestEntryUnknownToUnknownWhenConfirmationWithUnknownAddress(t *testing.T) { func TestEntryUnknownToIncomplete(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -350,9 +342,11 @@ func TestEntryUnknownToIncomplete(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } { @@ -367,7 +361,7 @@ func TestEntryUnknownToIncomplete(t *testing.T) { func TestEntryUnknownToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) @@ -377,6 +371,7 @@ func TestEntryUnknownToStale(t *testing.T) { e.mu.Unlock() // No probes should have been sent. + runImmediatelyScheduledJobs(clock) linkRes.mu.Lock() diff := cmp.Diff(linkRes.probes, []entryTestProbeInfo(nil)) linkRes.mu.Unlock() @@ -388,9 +383,11 @@ func TestEntryUnknownToStale(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -406,11 +403,11 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } - updatedAt := e.neigh.UpdatedAt + updatedAtNanos := e.neigh.UpdatedAtNanos e.mu.Unlock() clock.Advance(c.RetransmitTimer) @@ -437,7 +434,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { } e.mu.Lock() - if got, want := e.neigh.UpdatedAt, updatedAt; got != want { + if got, want := e.neigh.UpdatedAtNanos, updatedAtNanos; got != want { t.Errorf("got e.neigh.UpdatedAt = %q, want = %q", got, want) } e.mu.Unlock() @@ -468,16 +465,20 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -487,7 +488,7 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { nudDisp.mu.Unlock() e.mu.Lock() - if got, notWant := e.neigh.UpdatedAt, updatedAt; got == notWant { + if got, notWant := e.neigh.UpdatedAtNanos, updatedAtNanos; got == notWant { t.Errorf("expected e.neigh.UpdatedAt to change, got = %q", got) } e.mu.Unlock() @@ -495,23 +496,16 @@ func TestEntryIncompleteToIncompleteDoesNotChangeUpdatedAt(t *testing.T) { func TestEntryIncompleteToReachable(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -526,20 +520,35 @@ func TestEntryIncompleteToReachable(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -555,7 +564,7 @@ func TestEntryIncompleteToReachable(t *testing.T) { // to Reachable. func TestEntryAddsAndClearsWakers(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) w := sleep.Waker{} s := sleep.Sleeper{} @@ -563,7 +572,25 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { defer s.Done() e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() if got := e.wakers; got != nil { t.Errorf("got e.wakers = %v, want = nil", got) } @@ -587,34 +614,24 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -626,26 +643,16 @@ func TestEntryAddsAndClearsWakers(t *testing.T) { func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: true, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.isRouter, true; got != want { - t.Errorf("got e.isRouter = %t, want = %t", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -659,20 +666,38 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { } linkRes.mu.Unlock() + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: true, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if !e.isRouter { + t.Errorf("got e.isRouter = %t, want = true", e.isRouter) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -684,23 +709,16 @@ func TestEntryIncompleteToReachableWithRouterFlag(t *testing.T) { func TestEntryIncompleteToStale(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Incomplete; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Incomplete { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Incomplete) } e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -715,20 +733,35 @@ func TestEntryIncompleteToStale(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -744,7 +777,7 @@ func TestEntryIncompleteToFailed(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Incomplete; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } @@ -783,16 +816,20 @@ func TestEntryIncompleteToFailed(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, } nudDisp.mu.Lock() @@ -817,12 +854,30 @@ func (*testLocker) Unlock() {} func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) ipv6EP := e.nic.networkEndpoints[header.IPv6ProtocolNumber].(*testIPv6Endpoint) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: true, Override: false, @@ -848,34 +903,24 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -893,27 +938,13 @@ func TestEntryStaysReachableWhenConfirmationWithRouterFlag(t *testing.T) { func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -928,20 +959,42 @@ func TestEntryStaysReachableWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -961,17 +1014,10 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -986,29 +1032,46 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1026,24 +1089,13 @@ func TestEntryReachableToStaleWhenTimeout(t *testing.T) { func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1058,27 +1110,48 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1086,38 +1159,17 @@ func TestEntryReachableToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1132,27 +1184,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1160,38 +1237,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddress(t *testing.T) t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: true, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1206,27 +1262,52 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: true, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1234,37 +1315,17 @@ func TestEntryReachableToStaleWhenConfirmationWithDifferentAddressAndOverride(t t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr1) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1279,20 +1340,42 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr1) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1304,31 +1387,13 @@ func TestEntryStaysStaleWhenProbeWithSameAddress(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1343,27 +1408,55 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1375,10 +1468,28 @@ func TestEntryStaleToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) + + e.mu.Lock() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } e.mu.Lock() - e.handlePacketQueuedLocked() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, @@ -1400,41 +1511,33 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, } nudDisp.mu.Lock() @@ -1446,31 +1549,13 @@ func TestEntryStaleToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1485,27 +1570,55 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1517,27 +1630,13 @@ func TestEntryStaleToStaleWhenOverrideConfirmation(t *testing.T) { func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1552,27 +1651,51 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1584,24 +1707,13 @@ func TestEntryStaleToStaleWhenProbeUpdateAddress(t *testing.T) { func TestEntryStaleToDelay(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1616,27 +1728,48 @@ func TestEntryStaleToDelay(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1656,22 +1789,10 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleUpperLevelConfirmationLocked() - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1686,43 +1807,68 @@ func TestEntryDelayToReachableWhenUpperLevelConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleUpperLevelConfirmationLocked() + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1743,29 +1889,10 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: true, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Reachable; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr2; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1780,43 +1907,75 @@ func TestEntryDelayToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } - clock.Advance(c.BaseReachableTime) + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: true, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Reachable { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Reachable) + } + if e.neigh.LinkAddr != entryTestLinkAddr2 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr2) + } + e.mu.Unlock() + clock.Advance(c.BaseReachableTime) wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1837,13 +1996,31 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if e.neigh.State != Delay { t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) } @@ -1860,57 +2037,52 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing } e.mu.Unlock() - wantProbes := []entryTestProbeInfo{ - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) - } - clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -1922,32 +2094,13 @@ func TestEntryDelayToReachableWhenSolicitedConfirmationWithoutAddress(t *testing func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - if got, want := e.neigh.LinkAddr, entryTestLinkAddr1; got != want { - t.Errorf("got e.neigh.LinkAddr = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -1962,27 +2115,56 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + if e.neigh.LinkAddr != entryTestLinkAddr1 { + t.Errorf("got e.neigh.LinkAddr = %q, want = %q", e.neigh.LinkAddr, entryTestLinkAddr1) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, } nudDisp.mu.Lock() @@ -1994,25 +2176,13 @@ func TestEntryStaysDelayWhenOverrideConfirmationWithSameAddress(t *testing.T) { func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleProbeLocked(entryTestLinkAddr2) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2027,34 +2197,58 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleProbeLocked(entryTestLinkAddr2) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2066,29 +2260,13 @@ func TestEntryDelayToStaleWhenProbeWithDifferentAddress(t *testing.T) { func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { c := DefaultNUDConfigurations() - e, nudDisp, linkRes, _ := entryTestSetup(c) + e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() - e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ - Solicited: false, - Override: false, - IsRouter: false, - }) - e.handlePacketQueuedLocked() - if got, want := e.neigh.State, Delay; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ - Solicited: false, - Override: true, - IsRouter: false, - }) - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() + runImmediatelyScheduledJobs(clock) wantProbes := []entryTestProbeInfo{ { RemoteAddress: entryTestAddr1, @@ -2103,34 +2281,62 @@ func TestEntryDelayToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) } + e.mu.Lock() + e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ + Solicited: false, + Override: false, + IsRouter: false, + }) + e.handlePacketQueuedLocked(entryTestAddr2) + if e.neigh.State != Delay { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Delay) + } + e.handleConfirmationLocked(entryTestLinkAddr2, ReachabilityConfirmationFlags{ + Solicited: false, + Override: true, + IsRouter: false, + }) + if e.neigh.State != Stale { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Stale) + } + e.mu.Unlock() + wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2145,69 +2351,91 @@ func TestEntryDelayToProbe(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) if got, want := e.neigh.State, Delay; got != want { t.Errorf("got e.neigh.State = %q, want = %q", got, want) } e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2228,36 +2456,50 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2274,37 +2516,47 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2312,12 +2564,6 @@ func TestEntryProbeToStaleWhenProbeWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { @@ -2325,36 +2571,50 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2375,37 +2635,47 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2413,12 +2683,6 @@ func TestEntryProbeToStaleWhenConfirmationWithDifferentAddress(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Stale; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { @@ -2426,36 +2690,51 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The second probe is caused by the Delay-to-Probe transition. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2479,30 +2758,38 @@ func TestEntryStaysProbeWhenOverrideConfirmationWithSameAddress(t *testing.T) { { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -2529,17 +2816,14 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Lock() e.handleProbeLocked(entryTestLinkAddr1) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - wantProbes := []entryTestProbeInfo{ - // Probe caused by the Delay-to-Probe transition { RemoteAddress: entryTestAddr1, RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, }, } linkRes.mu.Lock() @@ -2567,42 +2851,51 @@ func TestEntryUnknownToStaleToProbeToReachable(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2622,36 +2915,50 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2672,49 +2979,60 @@ func TestEntryProbeToReachableWhenSolicitedOverrideConfirmation(t *testing.T) { e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr2, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr2, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2734,36 +3052,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2781,49 +3113,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithSameAddress(t *testin e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2843,36 +3186,50 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: tcpip.LinkAddress(""), + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() clock.Advance(c.DelayFirstProbeTime) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The second probe is caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } e.mu.Lock() @@ -2890,49 +3247,60 @@ func TestEntryProbeToReachableWhenSolicitedConfirmationWithoutAddress(t *testing e.mu.Unlock() clock.Advance(c.BaseReachableTime) - wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Reachable, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Reachable, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, } nudDisp.mu.Lock() @@ -2946,87 +3314,116 @@ func TestEntryProbeToFailed(t *testing.T) { c := DefaultNUDConfigurations() c.MaxMulticastProbes = 3 c.MaxUnicastProbes = 3 + c.DelayFirstProbeTime = c.RetransmitTimer e, nudDisp, linkRes, clock := entryTestSetup(c) e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() - waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) - clock.Advance(waitFor) + // Observe each probe sent while in the Probe state. + for i := uint32(0); i < c.MaxUnicastProbes; i++ { + clock.Advance(c.RetransmitTimer) + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probe #%d mismatch (-got, +want):\n%s", i+1, diff) + } - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, + e.mu.Lock() + if e.neigh.State != Probe { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Probe) + } + e.mu.Unlock() } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + + // Wait for the last probe to expire, causing a transition to Failed. + clock.Advance(c.RetransmitTimer) + e.mu.Lock() + if e.neigh.State != Failed { + t.Errorf("got e.neigh.State = %q, want = %q", e.neigh.State, Failed) } + e.mu.Unlock() wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() @@ -3034,12 +3431,6 @@ func TestEntryProbeToFailed(t *testing.T) { t.Errorf("nud dispatcher events mismatch (-got, +want):\n%s", diff) } nudDisp.mu.Unlock() - - e.mu.Lock() - if got, want := e.neigh.State, Failed; got != want { - t.Errorf("got e.neigh.State = %q, want = %q", got, want) - } - e.mu.Unlock() } func TestEntryFailedGetsDeleted(t *testing.T) { @@ -3054,84 +3445,106 @@ func TestEntryFailedGetsDeleted(t *testing.T) { } e.mu.Lock() - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) + e.mu.Unlock() + + runImmediatelyScheduledJobs(clock) + { + wantProbes := []entryTestProbeInfo{ + { + RemoteAddress: entryTestAddr1, + LocalAddress: entryTestAddr2, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.probes = nil + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } + } + + e.mu.Lock() e.handleConfirmationLocked(entryTestLinkAddr1, ReachabilityConfirmationFlags{ Solicited: false, Override: false, IsRouter: false, }) - e.handlePacketQueuedLocked() + e.handlePacketQueuedLocked(entryTestAddr2) e.mu.Unlock() waitFor := c.DelayFirstProbeTime + c.RetransmitTimer*time.Duration(c.MaxUnicastProbes) + c.UnreachableTime clock.Advance(waitFor) - - wantProbes := []entryTestProbeInfo{ - // The first probe is caused by the Unknown-to-Incomplete transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: tcpip.LinkAddress(""), - LocalAddress: entryTestAddr2, - }, - // The next three probe are caused by the Delay-to-Probe transition. - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - { - RemoteAddress: entryTestAddr1, - RemoteLinkAddress: entryTestLinkAddr1, - LocalAddress: entryTestAddr2, - }, - } - linkRes.mu.Lock() - diff := cmp.Diff(linkRes.probes, wantProbes) - linkRes.mu.Unlock() - if diff != "" { - t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + { + wantProbes := []entryTestProbeInfo{ + // The next three probe are sent in Probe. + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + { + RemoteAddress: entryTestAddr1, + RemoteLinkAddress: entryTestLinkAddr1, + }, + } + linkRes.mu.Lock() + diff := cmp.Diff(linkRes.probes, wantProbes) + linkRes.mu.Unlock() + if diff != "" { + t.Fatalf("link address resolver probes mismatch (-got, +want):\n%s", diff) + } } wantEvents := []testEntryEventInfo{ { EventType: entryTestAdded, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: tcpip.LinkAddress(""), - State: Incomplete, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: tcpip.LinkAddress(""), + State: Incomplete, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Stale, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Stale, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Delay, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Delay, + }, }, { EventType: entryTestChanged, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, { EventType: entryTestRemoved, NICID: entryTestNICID, - Addr: entryTestAddr1, - LinkAddr: entryTestLinkAddr1, - State: Probe, + Entry: NeighborEntry{ + Addr: entryTestAddr1, + LinkAddr: entryTestLinkAddr1, + State: Probe, + }, }, } nudDisp.mu.Lock() diff --git a/pkg/tcpip/stack/nic.go b/pkg/tcpip/stack/nic.go index dcd4319bf..60c81a3aa 100644 --- a/pkg/tcpip/stack/nic.go +++ b/pkg/tcpip/stack/nic.go @@ -273,6 +273,15 @@ func (n *NIC) WritePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumb return n.writePacket(r, gso, protocol, pkt) } +// WritePacketToRemote implements NetworkInterface. +func (n *NIC) WritePacketToRemote(remoteLinkAddr tcpip.LinkAddress, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { + r := Route{ + NetProto: protocol, + RemoteLinkAddress: remoteLinkAddr, + } + return n.writePacket(&r, gso, protocol, pkt) +} + func (n *NIC) writePacket(r *Route, gso *GSO, protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) *tcpip.Error { // WritePacket takes ownership of pkt, calculate numBytes first. numBytes := pkt.Size() @@ -339,6 +348,16 @@ func (n *NIC) getAddress(protocol tcpip.NetworkProtocolNumber, dst tcpip.Address return n.getAddressOrCreateTemp(protocol, dst, CanBePrimaryEndpoint, promiscuous) } +func (n *NIC) hasAddress(protocol tcpip.NetworkProtocolNumber, addr tcpip.Address) bool { + ep := n.getAddressOrCreateTempInner(protocol, addr, false, NeverPrimaryEndpoint) + if ep != nil { + ep.DecRef() + return true + } + + return false +} + // findEndpoint finds the endpoint, if any, with the given address. func (n *NIC) findEndpoint(protocol tcpip.NetworkProtocolNumber, address tcpip.Address, peb PrimaryEndpointBehavior) AssignableAddressEndpoint { return n.getAddressOrCreateTemp(protocol, address, peb, spoofing) @@ -546,10 +565,10 @@ func (n *NIC) isInGroup(addr tcpip.Address) bool { } func (n *NIC) handlePacket(protocol tcpip.NetworkProtocolNumber, dst, src tcpip.Address, remotelinkAddr tcpip.LinkAddress, addressEndpoint AssignableAddressEndpoint, pkt *PacketBuffer) { - r := makeRoute(protocol, dst, src, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) + r := makeRoute(protocol, dst, src, n, n, addressEndpoint, false /* handleLocal */, false /* multicastLoop */) defer r.Release() - r.RemoteLinkAddress = remotelinkAddr - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + r.PopulatePacketInfo(pkt) + n.getNetworkEndpoint(protocol).HandlePacket(pkt) } // DeliverNetworkPacket finds the appropriate network protocol endpoint and @@ -585,6 +604,7 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp if local == "" { local = n.LinkEndpoint.LinkAddress() } + pkt.RXTransportChecksumValidated = n.LinkEndpoint.Capabilities()&CapabilityRXChecksumOffload != 0 // Are any packet type sockets listening for this network protocol? packetEPs := n.mu.packetEPs[protocol] @@ -660,14 +680,13 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp } // Found a NIC. - n := r.nic + n := r.localAddressNIC if addressEndpoint := n.getAddressOrCreateTempInner(protocol, dst, false, NeverPrimaryEndpoint); addressEndpoint != nil { if n.isValidForOutgoing(addressEndpoint) { - r.LocalLinkAddress = n.LinkEndpoint.LinkAddress() - r.RemoteLinkAddress = remote + pkt.NICID = n.ID() r.RemoteAddress = src - // TODO(b/123449044): Update the source NIC as well. - n.getNetworkEndpoint(protocol).HandlePacket(&r, pkt) + pkt.NetworkPacketInfo = r.networkPacketInfo() + n.getNetworkEndpoint(protocol).HandlePacket(pkt) addressEndpoint.DecRef() r.Release() return @@ -678,7 +697,8 @@ func (n *NIC) DeliverNetworkPacket(remote, local tcpip.LinkAddress, protocol tcp // n doesn't have a destination endpoint. // Send the packet out of n. - // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease the TTL field for ipv4/ipv6. + // TODO(gvisor.dev/issue/1085): According to the RFC, we must decrease + // the TTL field for ipv4/ipv6. // pkt may have set its header and may not have enough headroom for // link-layer header for the other link to prepend. Here we create a new @@ -725,7 +745,7 @@ func (n *NIC) DeliverOutboundPacket(remote, local tcpip.LinkAddress, protocol tc // DeliverTransportPacket delivers the packets to the appropriate transport // protocol endpoint. -func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { +func (n *NIC) DeliverTransportPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition { state, ok := n.stack.transportProtocols[protocol] if !ok { n.stack.stats.UnknownProtocolRcvdPackets.Increment() @@ -737,7 +757,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // Raw socket packets are delivered based solely on the transport // protocol number. We do not inspect the payload to ensure it's // validly formed. - n.stack.demux.deliverRawPacket(r, protocol, pkt) + n.stack.demux.deliverRawPacket(protocol, pkt) // TransportHeader is empty only when pkt is an ICMP packet or was reassembled // from fragments. @@ -766,14 +786,25 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN return TransportPacketHandled } - id := TransportEndpointID{dstPort, r.LocalAddress, srcPort, r.RemoteAddress} - if n.stack.demux.deliverPacket(r, protocol, pkt, id) { + netProto, ok := n.stack.networkProtocols[pkt.NetworkProtocolNumber] + if !ok { + panic(fmt.Sprintf("expected network protocol = %d, have = %#v", pkt.NetworkProtocolNumber, n.stack.networkProtocolNumbers())) + } + + src, dst := netProto.ParseAddresses(pkt.NetworkHeader().View()) + id := TransportEndpointID{ + LocalPort: dstPort, + LocalAddress: dst, + RemotePort: srcPort, + RemoteAddress: src, + } + if n.stack.demux.deliverPacket(protocol, pkt, id) { return TransportPacketHandled } // Try to deliver to per-stack default handler. if state.defaultHandler != nil { - if state.defaultHandler(r, id, pkt) { + if state.defaultHandler(id, pkt) { return TransportPacketHandled } } @@ -781,7 +812,7 @@ func (n *NIC) DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolN // We could not find an appropriate destination for this packet so // give the protocol specific error handler a chance to handle it. // If it doesn't handle it then we should do so. - switch res := transProto.HandleUnknownDestinationPacket(r, id, pkt); res { + switch res := transProto.HandleUnknownDestinationPacket(id, pkt); res { case UnknownDestinationPacketMalformed: n.stack.stats.MalformedRcvdPackets.Increment() return TransportPacketHandled @@ -885,7 +916,7 @@ func (n *NIC) unregisterPacketEndpoint(netProto tcpip.NetworkProtocolNumber, ep } // isValidForOutgoing returns true if the endpoint can be used to send out a -// packet. It requires the endpoint to not be marked expired (i.e., its address) +// packet. It requires the endpoint to not be marked expired (i.e., its address // has been removed) unless the NIC is in spoofing mode, or temporary. func (n *NIC) isValidForOutgoing(ep AssignableAddressEndpoint) bool { n.mu.RLock() diff --git a/pkg/tcpip/stack/nic_test.go b/pkg/tcpip/stack/nic_test.go index 97a96af62..5b5c58afb 100644 --- a/pkg/tcpip/stack/nic_test.go +++ b/pkg/tcpip/stack/nic_test.go @@ -83,8 +83,7 @@ func (*testIPv6Endpoint) WriteHeaderIncludedPacket(*Route, *PacketBuffer) *tcpip } // HandlePacket implements NetworkEndpoint.HandlePacket. -func (*testIPv6Endpoint) HandlePacket(*Route, *PacketBuffer) { -} +func (*testIPv6Endpoint) HandlePacket(*PacketBuffer) {} // Close implements NetworkEndpoint.Close. func (e *testIPv6Endpoint) Close() { @@ -169,7 +168,7 @@ func (*testIPv6Protocol) LinkAddressProtocol() tcpip.NetworkProtocolNumber { } // LinkAddressRequest implements LinkAddressResolver. -func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ LinkEndpoint) *tcpip.Error { +func (*testIPv6Protocol) LinkAddressRequest(_, _ tcpip.Address, _ tcpip.LinkAddress, _ NetworkInterface) *tcpip.Error { return nil } diff --git a/pkg/tcpip/stack/nud.go b/pkg/tcpip/stack/nud.go index e1ec15487..ab629b3a4 100644 --- a/pkg/tcpip/stack/nud.go +++ b/pkg/tcpip/stack/nud.go @@ -129,7 +129,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborAdded(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborAdded(tcpip.NICID, NeighborEntry) // OnNeighborChanged will be called when an entry in a NIC's (with ID nicID) // neighbor table changes state and/or link address. @@ -138,7 +138,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborChanged(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborChanged(tcpip.NICID, NeighborEntry) // OnNeighborRemoved will be called when an entry is removed from a NIC's // (with ID nicID) neighbor table. @@ -147,7 +147,7 @@ type NUDDispatcher interface { // the stack's operation. // // May be called concurrently. - OnNeighborRemoved(nicID tcpip.NICID, ipAddr tcpip.Address, linkAddr tcpip.LinkAddress, state NeighborState, updatedAt time.Time) + OnNeighborRemoved(tcpip.NICID, NeighborEntry) } // ReachabilityConfirmationFlags describes the flags used within a reachability @@ -177,7 +177,7 @@ type NUDHandler interface { // Neighbor Solicitation for ARP or NDP, respectively). Validation of the // probe needs to be performed before calling this function since the // Neighbor Cache doesn't have access to view the NIC's assigned addresses. - HandleProbe(remoteAddr, localAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) + HandleProbe(remoteAddr tcpip.Address, protocol tcpip.NetworkProtocolNumber, remoteLinkAddr tcpip.LinkAddress, linkRes LinkAddressResolver) // HandleConfirmation processes an incoming neighbor confirmation (e.g. ARP // reply or Neighbor Advertisement for ARP or NDP, respectively). diff --git a/pkg/tcpip/stack/packet_buffer.go b/pkg/tcpip/stack/packet_buffer.go index 7f54a6de8..664cc6fa0 100644 --- a/pkg/tcpip/stack/packet_buffer.go +++ b/pkg/tcpip/stack/packet_buffer.go @@ -112,6 +112,16 @@ type PacketBuffer struct { // PktType indicates the SockAddrLink.PacketType of the packet as defined in // https://www.man7.org/linux/man-pages/man7/packet.7.html. PktType tcpip.PacketType + + // NICID is the ID of the interface the network packet was received at. + NICID tcpip.NICID + + // RXTransportChecksumValidated indicates that transport checksum verification + // may be safely skipped. + RXTransportChecksumValidated bool + + // NetworkPacketInfo holds an incoming packet's network-layer information. + NetworkPacketInfo NetworkPacketInfo } // NewPacketBuffer creates a new PacketBuffer with opts. @@ -240,20 +250,33 @@ func (pk *PacketBuffer) consume(typ headerType, size int) (v buffer.View, consum // Clone should be called in such cases so that no modifications is done to // underlying packet payload. func (pk *PacketBuffer) Clone() *PacketBuffer { - newPk := &PacketBuffer{ - PacketBufferEntry: pk.PacketBufferEntry, - Data: pk.Data.Clone(nil), - headers: pk.headers, - header: pk.header, - Hash: pk.Hash, - Owner: pk.Owner, - EgressRoute: pk.EgressRoute, - GSOOptions: pk.GSOOptions, - NetworkProtocolNumber: pk.NetworkProtocolNumber, - NatDone: pk.NatDone, - TransportProtocolNumber: pk.TransportProtocolNumber, + return &PacketBuffer{ + PacketBufferEntry: pk.PacketBufferEntry, + Data: pk.Data.Clone(nil), + headers: pk.headers, + header: pk.header, + Hash: pk.Hash, + Owner: pk.Owner, + GSOOptions: pk.GSOOptions, + NetworkProtocolNumber: pk.NetworkProtocolNumber, + NatDone: pk.NatDone, + TransportProtocolNumber: pk.TransportProtocolNumber, + PktType: pk.PktType, + NICID: pk.NICID, + RXTransportChecksumValidated: pk.RXTransportChecksumValidated, + NetworkPacketInfo: pk.NetworkPacketInfo, } - return newPk +} + +// SourceLinkAddress returns the source link address of the packet. +func (pk *PacketBuffer) SourceLinkAddress() tcpip.LinkAddress { + link := pk.LinkHeader().View() + + if link.IsEmpty() { + return "" + } + + return header.Ethernet(link).SourceAddress() } // Network returns the network header as a header.Network. @@ -270,6 +293,17 @@ func (pk *PacketBuffer) Network() header.Network { } } +// CloneToInbound makes a shallow copy of the packet buffer to be used as an +// inbound packet. +// +// See PacketBuffer.Data for details about how a packet buffer holds an inbound +// packet. +func (pk *PacketBuffer) CloneToInbound() *PacketBuffer { + return NewPacketBuffer(PacketBufferOptions{ + Data: buffer.NewVectorisedView(pk.Size(), pk.Views()), + }) +} + // headerInfo stores metadata about a header in a packet. type headerInfo struct { // buf is the memorized slice for both prepended and consumed header. diff --git a/pkg/tcpip/stack/pending_packets.go b/pkg/tcpip/stack/pending_packets.go index f838eda8d..5d364a2b0 100644 --- a/pkg/tcpip/stack/pending_packets.go +++ b/pkg/tcpip/stack/pending_packets.go @@ -106,7 +106,7 @@ func (f *packetsPendingLinkResolution) enqueue(ch <-chan struct{}, r *Route, pro } else if _, err := p.route.Resolve(nil); err != nil { p.route.Stats().IP.OutgoingPacketErrors.Increment() } else { - p.route.nic.writePacket(p.route, nil /* gso */, p.proto, p.pkt) + p.route.outgoingNIC.writePacket(p.route, nil /* gso */, p.proto, p.pkt) } p.route.Release() } diff --git a/pkg/tcpip/stack/registration.go b/pkg/tcpip/stack/registration.go index defb9129b..b8f333057 100644 --- a/pkg/tcpip/stack/registration.go +++ b/pkg/tcpip/stack/registration.go @@ -63,17 +63,28 @@ const ( ControlUnknown ) +// NetworkPacketInfo holds information about a network layer packet. +type NetworkPacketInfo struct { + // RemoteAddressBroadcast is true if the packet's remote address is a + // broadcast address. + RemoteAddressBroadcast bool + + // LocalAddressBroadcast is true if the packet's local address is a broadcast + // address. + LocalAddressBroadcast bool +} + // TransportEndpoint is the interface that needs to be implemented by transport // protocol (e.g., tcp, udp) endpoints that can handle packets. type TransportEndpoint interface { // UniqueID returns an unique ID for this transport endpoint. UniqueID() uint64 - // HandlePacket is called by the stack when new packets arrive to - // this transport endpoint. It sets pkt.TransportHeader. + // HandlePacket is called by the stack when new packets arrive to this + // transport endpoint. It sets the packet buffer's transport header. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(TransportEndpointID, *PacketBuffer) // HandleControlPacket is called by the stack when new control (e.g. // ICMP) packets arrive to this transport endpoint. @@ -105,8 +116,8 @@ type RawTransportEndpoint interface { // this transport endpoint. The packet contains all data from the link // layer up. // - // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + // HandlePacket takes ownership of the packet. + HandlePacket(*PacketBuffer) } // PacketEndpoint is the interface that needs to be implemented by packet @@ -172,9 +183,9 @@ type TransportProtocol interface { // protocol that don't match any existing endpoint. For example, // it is targeted at a port that has no listeners. // - // HandleUnknownDestinationPacket takes ownership of pkt if it handles + // HandleUnknownDestinationPacket takes ownership of the packet if it handles // the issue. - HandleUnknownDestinationPacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) UnknownDestinationPacketDisposition + HandleUnknownDestinationPacket(TransportEndpointID, *PacketBuffer) UnknownDestinationPacketDisposition // SetOption allows enabling/disabling protocol specific features. // SetOption returns an error if the option is not supported or the @@ -227,8 +238,8 @@ type TransportDispatcher interface { // // pkt.NetworkHeader must be set before calling DeliverTransportPacket. // - // DeliverTransportPacket takes ownership of pkt. - DeliverTransportPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) TransportPacketDisposition + // DeliverTransportPacket takes ownership of the packet. + DeliverTransportPacket(tcpip.TransportProtocolNumber, *PacketBuffer) TransportPacketDisposition // DeliverTransportControlPacket delivers control packets to the // appropriate transport protocol endpoint. @@ -329,6 +340,9 @@ type AssignableAddressEndpoint interface { // AddressWithPrefix returns the endpoint's address. AddressWithPrefix() tcpip.AddressWithPrefix + // Subnet returns the subnet of the endpoint's address. + Subnet() tcpip.Subnet + // IsAssigned returns whether or not the endpoint is considered bound // to its NetworkEndpoint. IsAssigned(allowExpired bool) bool @@ -490,6 +504,9 @@ type NetworkInterface interface { // Enabled returns true if the interface is enabled. Enabled() bool + + // WritePacketToRemote writes the packet to the given remote link address. + WritePacketToRemote(tcpip.LinkAddress, *GSO, tcpip.NetworkProtocolNumber, *PacketBuffer) *tcpip.Error } // NetworkEndpoint is the interface that needs to be implemented by endpoints @@ -544,7 +561,7 @@ type NetworkEndpoint interface { // this network endpoint. It sets pkt.NetworkHeader. // // HandlePacket takes ownership of pkt. - HandlePacket(r *Route, pkt *PacketBuffer) + HandlePacket(pkt *PacketBuffer) // Close is called when the endpoint is reomved from a stack. Close() @@ -764,13 +781,13 @@ type InjectableLinkEndpoint interface { // A LinkAddressResolver is an extension to a NetworkProtocol that // can resolve link addresses. type LinkAddressResolver interface { - // LinkAddressRequest sends a request for the LinkAddress of addr. Broadcasts - // the request on the local network if remoteLinkAddr is the zero value. The - // request is sent on linkEP with localAddr as the source. + // LinkAddressRequest sends a request for the link address of the target + // address. The request is broadcasted on the local network if a remote link + // address is not provided. // - // A valid response will cause the discovery protocol's network - // endpoint to call AddLinkAddress. - LinkAddressRequest(addr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, linkEP LinkEndpoint) *tcpip.Error + // The request is sent from the passed network interface. If the interface + // local address is unspecified, any interface local address may be used. + LinkAddressRequest(targetAddr, localAddr tcpip.Address, remoteLinkAddr tcpip.LinkAddress, nic NetworkInterface) *tcpip.Error // ResolveStaticAddress attempts to resolve address without sending // requests. It either resolves the name immediately or returns the diff --git a/pkg/tcpip/stack/route.go b/pkg/tcpip/stack/route.go index b76e2d37b..15ff437c7 100644 --- a/pkg/tcpip/stack/route.go +++ b/pkg/tcpip/stack/route.go @@ -15,6 +15,8 @@ package stack import ( + "fmt" + "gvisor.dev/gvisor/pkg/sleep" "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/header" @@ -45,11 +47,16 @@ type Route struct { // Loop controls where WritePacket should send packets. Loop PacketLooping - // nic is the NIC the route goes through. - nic *NIC + // localAddressNIC is the interface the address is associated with. + // TODO(gvisor.dev/issue/4548): Remove this field once we can query the + // address's assigned status without the NIC. + localAddressNIC *NIC + + // localAddressEndpoint is the local address this route is associated with. + localAddressEndpoint AssignableAddressEndpoint - // addressEndpoint is the local address this route is associated with. - addressEndpoint AssignableAddressEndpoint + // outgoingNIC is the interface this route uses to write packets. + outgoingNIC *NIC // linkCache is set if link address resolution is enabled for this protocol on // the route's NIC. @@ -60,51 +67,144 @@ type Route struct { linkRes LinkAddressResolver } +// constructAndValidateRoute validates and initializes a route. It takes +// ownership of the provided local address. +// +// Returns an empty route if validation fails. +func constructAndValidateRoute(netProto tcpip.NetworkProtocolNumber, addressEndpoint AssignableAddressEndpoint, localAddressNIC, outgoingNIC *NIC, gateway, remoteAddr tcpip.Address, handleLocal, multicastLoop bool) Route { + addrWithPrefix := addressEndpoint.AddressWithPrefix() + + if localAddressNIC != outgoingNIC && header.IsV6LinkLocalAddress(addrWithPrefix.Address) { + addressEndpoint.DecRef() + return Route{} + } + + // If no remote address is provided, use the local address. + if len(remoteAddr) == 0 { + remoteAddr = addrWithPrefix.Address + } + + r := makeRoute( + netProto, + addrWithPrefix.Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + addressEndpoint, + handleLocal, + multicastLoop, + ) + + // If the route requires us to send a packet through some gateway, do not + // broadcast it. + if len(gateway) > 0 { + r.NextHop = gateway + } else if subnet := addrWithPrefix.Subnet(); subnet.IsBroadcast(remoteAddr) { + r.RemoteLinkAddress = header.EthernetBroadcastAddress + } + + return r +} + // makeRoute initializes a new route. It takes ownership of the provided // AssignableAddressEndpoint. -func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, nic *NIC, addressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { +func makeRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, handleLocal, multicastLoop bool) Route { + if localAddressNIC.stack != outgoingNIC.stack { + panic(fmt.Sprintf("cannot create a route with NICs from different stacks")) + } + loop := PacketOut - if handleLocal && localAddr != "" && remoteAddr == localAddr { - loop = PacketLoop - } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { - loop |= PacketLoop - } else if remoteAddr == header.IPv4Broadcast { - loop |= PacketLoop + + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if !outgoingNIC.IsLoopback() { + if handleLocal && localAddr != "" && remoteAddr == localAddr { + loop = PacketLoop + } else if multicastLoop && (header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr)) { + loop |= PacketLoop + } else if remoteAddr == header.IPv4Broadcast { + loop |= PacketLoop + } else if subnet := localAddressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { + loop |= PacketLoop + } } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +func makeRouteInner(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint, loop PacketLooping) Route { r := Route{ - NetProto: netProto, - LocalAddress: localAddr, - LocalLinkAddress: nic.LinkEndpoint.LinkAddress(), - RemoteAddress: remoteAddr, - addressEndpoint: addressEndpoint, - nic: nic, - Loop: loop, + NetProto: netProto, + LocalAddress: localAddr, + LocalLinkAddress: outgoingNIC.LinkEndpoint.LinkAddress(), + RemoteAddress: remoteAddr, + localAddressNIC: localAddressNIC, + localAddressEndpoint: localAddressEndpoint, + outgoingNIC: outgoingNIC, + Loop: loop, } - if r.nic.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { - if linkRes, ok := r.nic.stack.linkAddrResolvers[r.NetProto]; ok { + if r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityResolutionRequired != 0 { + if linkRes, ok := r.outgoingNIC.stack.linkAddrResolvers[r.NetProto]; ok { r.linkRes = linkRes - r.linkCache = r.nic.stack + r.linkCache = r.outgoingNIC.stack } } return r } +// makeLocalRoute initializes a new local route. It takes ownership of the +// provided AssignableAddressEndpoint. +// +// A local route is a route to a destination that is local to the stack. +func makeLocalRoute(netProto tcpip.NetworkProtocolNumber, localAddr, remoteAddr tcpip.Address, outgoingNIC, localAddressNIC *NIC, localAddressEndpoint AssignableAddressEndpoint) Route { + loop := PacketLoop + // TODO(gvisor.dev/issue/4689): Loopback interface loops back packets at the + // link endpoint level. We can remove this check once loopback interfaces + // loop back packets at the network layer. + if outgoingNIC.IsLoopback() { + loop = PacketOut + } + return makeRouteInner(netProto, localAddr, remoteAddr, outgoingNIC, localAddressNIC, localAddressEndpoint, loop) +} + +// PopulatePacketInfo populates a packet buffer's packet information fields. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) PopulatePacketInfo(pkt *PacketBuffer) { + if r.local() { + pkt.RXTransportChecksumValidated = true + } + pkt.NetworkPacketInfo = r.networkPacketInfo() +} + +// networkPacketInfo returns the network packet information of the route. +// +// TODO(gvisor.dev/issue/4688): Remove this once network packets are handled by +// the network layer. +func (r *Route) networkPacketInfo() NetworkPacketInfo { + return NetworkPacketInfo{ + RemoteAddressBroadcast: r.IsOutboundBroadcast(), + LocalAddressBroadcast: r.isInboundBroadcast(), + } +} + // NICID returns the id of the NIC from which this route originates. func (r *Route) NICID() tcpip.NICID { - return r.nic.ID() + return r.outgoingNIC.ID() } // MaxHeaderLength forwards the call to the network endpoint's implementation. func (r *Route) MaxHeaderLength() uint16 { - return r.nic.getNetworkEndpoint(r.NetProto).MaxHeaderLength() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MaxHeaderLength() } // Stats returns a mutable copy of current stats. func (r *Route) Stats() tcpip.Stats { - return r.nic.stack.Stats() + return r.outgoingNIC.stack.Stats() } // PseudoHeaderChecksum forwards the call to the network endpoint's @@ -113,14 +213,38 @@ func (r *Route) PseudoHeaderChecksum(protocol tcpip.TransportProtocolNumber, tot return header.PseudoHeaderChecksum(protocol, r.LocalAddress, r.RemoteAddress, totalLen) } -// Capabilities returns the link-layer capabilities of the route. -func (r *Route) Capabilities() LinkEndpointCapabilities { - return r.nic.LinkEndpoint.Capabilities() +// RequiresTXTransportChecksum returns false if the route does not require +// transport checksums to be populated. +func (r *Route) RequiresTXTransportChecksum() bool { + if r.local() { + return false + } + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityTXChecksumOffload == 0 +} + +// HasSoftwareGSOCapability returns true if the route supports software GSO. +func (r *Route) HasSoftwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySoftwareGSO != 0 +} + +// HasHardwareGSOCapability returns true if the route supports hardware GSO. +func (r *Route) HasHardwareGSOCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityHardwareGSO != 0 +} + +// HasSaveRestoreCapability returns true if the route supports save/restore. +func (r *Route) HasSaveRestoreCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilitySaveRestore != 0 +} + +// HasDisconncetOkCapability returns true if the route supports disconnecting. +func (r *Route) HasDisconncetOkCapability() bool { + return r.outgoingNIC.LinkEndpoint.Capabilities()&CapabilityDisconnectOk != 0 } // GSOMaxSize returns the maximum GSO packet size. func (r *Route) GSOMaxSize() uint32 { - if gso, ok := r.nic.LinkEndpoint.(GSOEndpoint); ok { + if gso, ok := r.outgoingNIC.LinkEndpoint.(GSOEndpoint); ok { return gso.GSOMaxSize() } return 0 @@ -158,8 +282,15 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { - entry, ch, err := neigh.entry(nextAddr, r.LocalAddress, r.linkRes, waker) + // If specified, the local address used for link address resolution must be an + // address on the outgoing interface. + var linkAddressResolutionRequestLocalAddr tcpip.Address + if r.localAddressNIC == r.outgoingNIC { + linkAddressResolutionRequestLocalAddr = r.LocalAddress + } + + if neigh := r.outgoingNIC.neigh; neigh != nil { + entry, ch, err := neigh.entry(nextAddr, linkAddressResolutionRequestLocalAddr, r.linkRes, waker) if err != nil { return ch, err } @@ -167,7 +298,7 @@ func (r *Route) Resolve(waker *sleep.Waker) (<-chan struct{}, *tcpip.Error) { return nil, nil } - linkAddr, ch, err := r.linkCache.GetLinkAddress(r.nic.ID(), nextAddr, r.LocalAddress, r.NetProto, waker) + linkAddr, ch, err := r.linkCache.GetLinkAddress(r.outgoingNIC.ID(), nextAddr, linkAddressResolutionRequestLocalAddr, r.NetProto, waker) if err != nil { return ch, err } @@ -182,76 +313,102 @@ func (r *Route) RemoveWaker(waker *sleep.Waker) { nextAddr = r.RemoteAddress } - if neigh := r.nic.neigh; neigh != nil { + if neigh := r.outgoingNIC.neigh; neigh != nil { neigh.removeWaker(nextAddr, waker) return } - r.linkCache.RemoveWaker(r.nic.ID(), nextAddr, waker) + r.linkCache.RemoveWaker(r.outgoingNIC.ID(), nextAddr, waker) +} + +// local returns true if the route is a local route. +func (r *Route) local() bool { + return r.Loop == PacketLoop || r.outgoingNIC.IsLoopback() } // IsResolutionRequired returns true if Resolve() must be called to resolve -// the link address before the this route can be written to. +// the link address before the route can be written to. // -// The NIC r uses must not be locked. +// The NICs the route is associated with must not be locked. func (r *Route) IsResolutionRequired() bool { - if r.nic.neigh != nil { - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkRes != nil && r.RemoteLinkAddress == "" + if !r.isValidForOutgoing() || r.RemoteLinkAddress != "" || r.local() { + return false } - return r.nic.isValidForOutgoing(r.addressEndpoint) && r.linkCache != nil && r.RemoteLinkAddress == "" + + return (r.outgoingNIC.neigh != nil && r.linkRes != nil) || r.linkCache != nil +} + +func (r *Route) isValidForOutgoing() bool { + if !r.outgoingNIC.Enabled() { + return false + } + + if !r.localAddressNIC.isValidForOutgoing(r.localAddressEndpoint) { + return false + } + + // If the source NIC and outgoing NIC are different, make sure the stack has + // forwarding enabled, or the packet will be handled locally. + if r.outgoingNIC != r.localAddressNIC && !r.outgoingNIC.stack.Forwarding(r.NetProto) && (!r.outgoingNIC.stack.handleLocal || !r.outgoingNIC.hasAddress(r.NetProto, r.RemoteAddress)) { + return false + } + + return true } // WritePacket writes the packet through the given route. func (r *Route) WritePacket(gso *GSO, params NetworkHeaderParams, pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePacket(r, gso, params, pkt) } // WritePackets writes a list of n packets through the given route and returns // the number of packets written. func (r *Route) WritePackets(gso *GSO, pkts PacketBufferList, params NetworkHeaderParams) (int, *tcpip.Error) { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return 0, tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WritePackets(r, gso, pkts, params) } // WriteHeaderIncludedPacket writes a packet already containing a network // header through the given route. func (r *Route) WriteHeaderIncludedPacket(pkt *PacketBuffer) *tcpip.Error { - if !r.nic.isValidForOutgoing(r.addressEndpoint) { + if !r.isValidForOutgoing() { return tcpip.ErrInvalidEndpointState } - return r.nic.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).WriteHeaderIncludedPacket(r, pkt) } // DefaultTTL returns the default TTL of the underlying network endpoint. func (r *Route) DefaultTTL() uint8 { - return r.nic.getNetworkEndpoint(r.NetProto).DefaultTTL() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).DefaultTTL() } // MTU returns the MTU of the underlying network endpoint. func (r *Route) MTU() uint32 { - return r.nic.getNetworkEndpoint(r.NetProto).MTU() + return r.outgoingNIC.getNetworkEndpoint(r.NetProto).MTU() } // Release frees all resources associated with the route. func (r *Route) Release() { - if r.addressEndpoint != nil { - r.addressEndpoint.DecRef() - r.addressEndpoint = nil + if r.localAddressEndpoint != nil { + r.localAddressEndpoint.DecRef() + r.localAddressEndpoint = nil } } // Clone clones the route. func (r *Route) Clone() Route { - if r.addressEndpoint != nil { - _ = r.addressEndpoint.IncRef() + if r.localAddressEndpoint != nil { + if !r.localAddressEndpoint.IncRef() { + panic(fmt.Sprintf("failed to increment reference count for local address endpoint = %s", r.LocalAddress)) + } } return *r } @@ -275,7 +432,7 @@ func (r *Route) MakeLoopedRoute() Route { // Stack returns the instance of the Stack that owns this route. func (r *Route) Stack() *Stack { - return r.nic.stack + return r.outgoingNIC.stack } func (r *Route) isV4Broadcast(addr tcpip.Address) bool { @@ -283,7 +440,7 @@ func (r *Route) isV4Broadcast(addr tcpip.Address) bool { return true } - subnet := r.addressEndpoint.AddressWithPrefix().Subnet() + subnet := r.localAddressEndpoint.Subnet() return subnet.IsBroadcast(addr) } @@ -294,9 +451,9 @@ func (r *Route) IsOutboundBroadcast() bool { return r.isV4Broadcast(r.RemoteAddress) } -// IsInboundBroadcast returns true if the route is for an inbound broadcast +// isInboundBroadcast returns true if the route is for an inbound broadcast // packet. -func (r *Route) IsInboundBroadcast() bool { +func (r *Route) isInboundBroadcast() bool { // Only IPv4 has a notion of broadcast. return r.isV4Broadcast(r.LocalAddress) } @@ -304,15 +461,16 @@ func (r *Route) IsInboundBroadcast() bool { // ReverseRoute returns new route with given source and destination address. func (r *Route) ReverseRoute(src tcpip.Address, dst tcpip.Address) Route { return Route{ - NetProto: r.NetProto, - LocalAddress: dst, - LocalLinkAddress: r.RemoteLinkAddress, - RemoteAddress: src, - RemoteLinkAddress: r.LocalLinkAddress, - Loop: r.Loop, - addressEndpoint: r.addressEndpoint, - nic: r.nic, - linkCache: r.linkCache, - linkRes: r.linkRes, + NetProto: r.NetProto, + LocalAddress: dst, + LocalLinkAddress: r.RemoteLinkAddress, + RemoteAddress: src, + RemoteLinkAddress: r.LocalLinkAddress, + Loop: r.Loop, + localAddressNIC: r.localAddressNIC, + localAddressEndpoint: r.localAddressEndpoint, + outgoingNIC: r.outgoingNIC, + linkCache: r.linkCache, + linkRes: r.linkRes, } } diff --git a/pkg/tcpip/stack/stack.go b/pkg/tcpip/stack/stack.go index 3a07577c8..a23fb97ff 100644 --- a/pkg/tcpip/stack/stack.go +++ b/pkg/tcpip/stack/stack.go @@ -22,6 +22,7 @@ package stack import ( "bytes" "encoding/binary" + "fmt" mathrand "math/rand" "sync/atomic" "time" @@ -52,7 +53,7 @@ const ( type transportProtocolState struct { proto TransportProtocol - defaultHandler func(r *Route, id TransportEndpointID, pkt *PacketBuffer) bool + defaultHandler func(id TransportEndpointID, pkt *PacketBuffer) bool } // TCPProbeFunc is the expected function type for a TCP probe function to be @@ -518,6 +519,10 @@ type Options struct { // // RandSource must be thread-safe. RandSource mathrand.Source + + // IPTables are the initial iptables rules. If nil, iptables will allow + // all traffic. + IPTables *IPTables } // TransportEndpointInfo holds useful information about a transport endpoint @@ -620,6 +625,10 @@ func New(opts Options) *Stack { randSrc = &lockedRandomSource{src: mathrand.NewSource(generateRandInt64())} } + if opts.IPTables == nil { + opts.IPTables = DefaultTables() + } + opts.NUDConfigs.resetInvalidFields() s := &Stack{ @@ -633,7 +642,7 @@ func New(opts Options) *Stack { clock: clock, stats: opts.Stats.FillIn(), handleLocal: opts.HandleLocal, - tables: DefaultTables(), + tables: opts.IPTables, icmpRateLimiter: NewICMPRateLimiter(), seed: generateRandUint32(), nudConfigs: opts.NUDConfigs, @@ -751,7 +760,7 @@ func (s *Stack) TransportProtocolOption(transport tcpip.TransportProtocolNumber, // // It must be called only during initialization of the stack. Changing it as the // stack is operating is not supported. -func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(*Route, TransportEndpointID, *PacketBuffer) bool) { +func (s *Stack) SetTransportProtocolHandler(p tcpip.TransportProtocolNumber, h func(TransportEndpointID, *PacketBuffer) bool) { state := s.transportProtocols[p] if state != nil { state.defaultHandler = h @@ -830,6 +839,20 @@ func (s *Stack) AddRoute(route tcpip.Route) { s.routeTable = append(s.routeTable, route) } +// RemoveRoutes removes matching routes from the route table. +func (s *Stack) RemoveRoutes(match func(tcpip.Route) bool) { + s.mu.Lock() + defer s.mu.Unlock() + + var filteredRoutes []tcpip.Route + for _, route := range s.routeTable { + if !match(route) { + filteredRoutes = append(filteredRoutes, route) + } + } + s.routeTable = filteredRoutes +} + // NewEndpoint creates a new transport layer endpoint of the given protocol. func (s *Stack) NewEndpoint(transport tcpip.TransportProtocolNumber, network tcpip.NetworkProtocolNumber, waiterQueue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { t, ok := s.transportProtocols[transport] @@ -1180,54 +1203,225 @@ func (s *Stack) getAddressEP(nic *NIC, localAddr, remoteAddr tcpip.Address, netP return nic.findEndpoint(netProto, localAddr, CanBePrimaryEndpoint) } +// findLocalRouteFromNICRLocked is like findLocalRouteRLocked but finds a route +// from the specified NIC. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteFromNICRLocked(localAddressNIC *NIC, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + localAddressEndpoint := localAddressNIC.getAddressOrCreateTempInner(netProto, localAddr, false /* createTemp */, NeverPrimaryEndpoint) + if localAddressEndpoint == nil { + return Route{}, false + } + + var outgoingNIC *NIC + // Prefer a local route to the same interface as the local address. + if localAddressNIC.hasAddress(netProto, remoteAddr) { + outgoingNIC = localAddressNIC + } + + // If the remote address isn't owned by the local address's NIC, check all + // NICs. + if outgoingNIC == nil { + for _, nic := range s.nics { + if nic.hasAddress(netProto, remoteAddr) { + outgoingNIC = nic + break + } + } + } + + // If the remote address is not owned by the stack, we can't return a local + // route. + if outgoingNIC == nil { + localAddressEndpoint.DecRef() + return Route{}, false + } + + r := makeLocalRoute( + netProto, + localAddressEndpoint.AddressWithPrefix().Address, + remoteAddr, + outgoingNIC, + localAddressNIC, + localAddressEndpoint, + ) + + if r.IsOutboundBroadcast() { + r.Release() + return Route{}, false + } + + return r, true +} + +// findLocalRouteRLocked returns a local route. +// +// A local route is a route to some remote address which the stack owns. That +// is, a local route is a route where packets never have to leave the stack. +// +// Precondition: s.mu must be read locked. +func (s *Stack) findLocalRouteRLocked(localAddressNICID tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber) (route Route, ok bool) { + if len(localAddr) == 0 { + localAddr = remoteAddr + } + + if localAddressNICID == 0 { + for _, localAddressNIC := range s.nics { + if r, ok := s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto); ok { + return r, true + } + } + + return Route{}, false + } + + if localAddressNIC, ok := s.nics[localAddressNICID]; ok { + return s.findLocalRouteFromNICRLocked(localAddressNIC, localAddr, remoteAddr, netProto) + } + + return Route{}, false +} + // FindRoute creates a route to the given destination address, leaving through -// the given nic and local address (if provided). +// the given NIC and local address (if provided). +// +// If a NIC is not specified, the returned route will leave through the same +// NIC as the NIC that has the local address assigned when forwarding is +// disabled. If forwarding is enabled and the NIC is unspecified, the route may +// leave through any interface unless the route is link-local. +// +// If no local address is provided, the stack will select a local address. If no +// remote address is provided, the stack wil use a remote address equal to the +// local address. func (s *Stack) FindRoute(id tcpip.NICID, localAddr, remoteAddr tcpip.Address, netProto tcpip.NetworkProtocolNumber, multicastLoop bool) (Route, *tcpip.Error) { s.mu.RLock() defer s.mu.RUnlock() + isLinkLocal := header.IsV6LinkLocalAddress(remoteAddr) || header.IsV6LinkLocalMulticastAddress(remoteAddr) isLocalBroadcast := remoteAddr == header.IPv4Broadcast isMulticast := header.IsV4MulticastAddress(remoteAddr) || header.IsV6MulticastAddress(remoteAddr) - needRoute := !(isLocalBroadcast || isMulticast || header.IsV6LinkLocalAddress(remoteAddr)) + isLoopback := header.IsV4LoopbackAddress(remoteAddr) || header.IsV6LoopbackAddress(remoteAddr) + needRoute := !(isLocalBroadcast || isMulticast || isLinkLocal || isLoopback) + + if s.handleLocal && !isMulticast && !isLocalBroadcast { + if r, ok := s.findLocalRouteRLocked(id, localAddr, remoteAddr, netProto); ok { + return r, nil + } + } + + // If the interface is specified and we do not need a route, return a route + // through the interface if the interface is valid and enabled. if id != 0 && !needRoute { if nic, ok := s.nics[id]; ok && nic.Enabled() { if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - return makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()), nil + return makeRoute( + netProto, + addressEndpoint.AddressWithPrefix().Address, + remoteAddr, + nic, /* outboundNIC */ + nic, /* localAddressNIC*/ + addressEndpoint, + s.handleLocal, + multicastLoop, + ), nil } } - } else { - for _, route := range s.routeTable { - if (id != 0 && id != route.NIC) || (len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr)) { - continue + + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable + } + + canForward := s.Forwarding(netProto) && !header.IsV6LinkLocalAddress(localAddr) && !isLinkLocal + + // Find a route to the remote with the route table. + var chosenRoute tcpip.Route + for _, route := range s.routeTable { + if len(remoteAddr) != 0 && !route.Destination.Contains(remoteAddr) { + continue + } + + nic, ok := s.nics[route.NIC] + if !ok || !nic.Enabled() { + continue + } + + if id == 0 || id == route.NIC { + if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { + var gateway tcpip.Address + if needRoute { + gateway = route.Gateway + } + r := constructAndValidateRoute(netProto, addressEndpoint, nic /* outgoingNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop) + if r == (Route{}) { + panic(fmt.Sprintf("non-forwarding route validation failed with route table entry = %#v, id = %d, localAddr = %s, remoteAddr = %s", route, id, localAddr, remoteAddr)) + } + return r, nil } - if nic, ok := s.nics[route.NIC]; ok && nic.Enabled() { - if addressEndpoint := s.getAddressEP(nic, localAddr, remoteAddr, netProto); addressEndpoint != nil { - if len(remoteAddr) == 0 { - // If no remote address was provided, then the route - // provided will refer to the link local address. - remoteAddr = addressEndpoint.AddressWithPrefix().Address - } + } - r := makeRoute(netProto, addressEndpoint.AddressWithPrefix().Address, remoteAddr, nic, addressEndpoint, s.handleLocal && !nic.IsLoopback(), multicastLoop && !nic.IsLoopback()) - if len(route.Gateway) > 0 { - if needRoute { - r.NextHop = route.Gateway - } - } else if subnet := addressEndpoint.AddressWithPrefix().Subnet(); subnet.IsBroadcast(remoteAddr) { - r.RemoteLinkAddress = header.EthernetBroadcastAddress + // If the stack has forwarding enabled and we haven't found a valid route to + // the remote address yet, keep track of the first valid route. We keep + // iterating because we prefer routes that let us use a local address that + // is assigned to the outgoing interface. There is no requirement to do this + // from any RFC but simply a choice made to better follow a strong host + // model which the netstack follows at the time of writing. + if canForward && chosenRoute == (tcpip.Route{}) { + chosenRoute = route + } + } + + if chosenRoute != (tcpip.Route{}) { + // At this point we know the stack has forwarding enabled since chosenRoute is + // only set when forwarding is enabled. + nic, ok := s.nics[chosenRoute.NIC] + if !ok { + // If the route's NIC was invalid, we should not have chosen the route. + panic(fmt.Sprintf("chosen route must have a valid NIC with ID = %d", chosenRoute.NIC)) + } + + var gateway tcpip.Address + if needRoute { + gateway = chosenRoute.Gateway + } + + // Use the specified NIC to get the local address endpoint. + if id != 0 { + if aNIC, ok := s.nics[id]; ok { + if addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto); addressEndpoint != nil { + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { + return r, nil } + } + } + + return Route{}, tcpip.ErrNoRoute + } + if id == 0 { + // If an interface is not specified, try to find a NIC that holds the local + // address endpoint to construct a route. + for _, aNIC := range s.nics { + addressEndpoint := s.getAddressEP(aNIC, localAddr, remoteAddr, netProto) + if addressEndpoint == nil { + continue + } + + if r := constructAndValidateRoute(netProto, addressEndpoint, aNIC /* localAddressNIC */, nic /* outgoingNIC */, gateway, remoteAddr, s.handleLocal, multicastLoop); r != (Route{}) { return r, nil } } } } - if !needRoute { - return Route{}, tcpip.ErrNetworkUnreachable + if needRoute { + return Route{}, tcpip.ErrNoRoute } - - return Route{}, tcpip.ErrNoRoute + if isLoopback { + return Route{}, tcpip.ErrBadLocalAddress + } + return Route{}, tcpip.ErrNetworkUnreachable } // CheckNetworkProtocol checks if a given network protocol is enabled in the @@ -1323,7 +1517,7 @@ func (s *Stack) GetLinkAddress(nicID tcpip.NICID, addr, localAddr tcpip.Address, fullAddr := tcpip.FullAddress{NIC: nicID, Addr: addr} linkRes := s.linkAddrResolvers[protocol] - return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic.LinkEndpoint, waker) + return s.linkAddrCache.get(fullAddr, linkRes, localAddr, nic, waker) } // Neighbors returns all IP to MAC address associations. @@ -1443,8 +1637,8 @@ func (s *Stack) CompleteTransportEndpointCleanup(ep TransportEndpoint) { // FindTransportEndpoint finds an endpoint that most closely matches the provided // id. If no endpoint is found it returns nil. -func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { - return s.demux.findTransportEndpoint(netProto, transProto, id, r) +func (s *Stack) FindTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { + return s.demux.findTransportEndpoint(netProto, transProto, id, nicID) } // RegisterRawTransportEndpoint registers the given endpoint with the stack @@ -1896,3 +2090,71 @@ func (s *Stack) FindNICNameFromID(id tcpip.NICID) string { func (s *Stack) NewJob(l sync.Locker, f func()) *tcpip.Job { return tcpip.NewJob(s.clock, l, f) } + +// ParseResult indicates the result of a parsing attempt. +type ParseResult int + +const ( + // ParsedOK indicates that a packet was successfully parsed. + ParsedOK ParseResult = iota + + // UnknownNetworkProtocol indicates that the network protocol is unknown. + UnknownNetworkProtocol + + // NetworkLayerParseError indicates that the network packet was not + // successfully parsed. + NetworkLayerParseError + + // UnknownTransportProtocol indicates that the transport protocol is unknown. + UnknownTransportProtocol + + // TransportLayerParseError indicates that the transport packet was not + // successfully parsed. + TransportLayerParseError +) + +// ParsePacketBuffer parses the provided packet buffer. +func (s *Stack) ParsePacketBuffer(protocol tcpip.NetworkProtocolNumber, pkt *PacketBuffer) ParseResult { + netProto, ok := s.networkProtocols[protocol] + if !ok { + return UnknownNetworkProtocol + } + + transProtoNum, hasTransportHdr, ok := netProto.Parse(pkt) + if !ok { + return NetworkLayerParseError + } + if !hasTransportHdr { + return ParsedOK + } + + // TODO(gvisor.dev/issue/170): ICMP packets don't have their TransportHeader + // fields set yet, parse it here. See icmp/protocol.go:protocol.Parse for a + // full explanation. + if transProtoNum == header.ICMPv4ProtocolNumber || transProtoNum == header.ICMPv6ProtocolNumber { + return ParsedOK + } + + pkt.TransportProtocolNumber = transProtoNum + // Parse the transport header if present. + state, ok := s.transportProtocols[transProtoNum] + if !ok { + return UnknownTransportProtocol + } + + if !state.proto.Parse(pkt) { + return TransportLayerParseError + } + + return ParsedOK +} + +// networkProtocolNumbers returns the network protocol numbers the stack is +// configured with. +func (s *Stack) networkProtocolNumbers() []tcpip.NetworkProtocolNumber { + protos := make([]tcpip.NetworkProtocolNumber, 0, len(s.networkProtocols)) + for p := range s.networkProtocols { + protos = append(protos, p) + } + return protos +} diff --git a/pkg/tcpip/stack/stack_test.go b/pkg/tcpip/stack/stack_test.go index e75f58c64..dedfdd435 100644 --- a/pkg/tcpip/stack/stack_test.go +++ b/pkg/tcpip/stack/stack_test.go @@ -21,6 +21,7 @@ import ( "bytes" "fmt" "math" + "net" "sort" "testing" "time" @@ -108,12 +109,13 @@ func (*fakeNetworkEndpoint) DefaultTTL() uint8 { return 123 } -func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuffer) { +func (f *fakeNetworkEndpoint) HandlePacket(pkt *stack.PacketBuffer) { // Increment the received packet count in the protocol descriptor. - f.proto.packetCount[int(r.LocalAddress[0])%len(f.proto.packetCount)]++ + netHdr := pkt.NetworkHeader().View() + f.proto.packetCount[int(netHdr[dstAddrOffset])%len(f.proto.packetCount)]++ // Handle control packets. - if pkt.NetworkHeader().View()[protocolNumberOffset] == uint8(fakeControlProtocol) { + if netHdr[protocolNumberOffset] == uint8(fakeControlProtocol) { nb, ok := pkt.Data.PullUp(fakeNetHeaderLen) if !ok { return @@ -129,7 +131,7 @@ func (f *fakeNetworkEndpoint) HandlePacket(r *stack.Route, pkt *stack.PacketBuff } // Dispatch the packet to the transport protocol. - f.dispatcher.DeliverTransportPacket(r, tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) + f.dispatcher.DeliverTransportPacket(tcpip.TransportProtocolNumber(pkt.NetworkHeader().View()[protocolNumberOffset]), pkt) } func (f *fakeNetworkEndpoint) MaxHeaderLength() uint16 { @@ -151,12 +153,15 @@ func (f *fakeNetworkEndpoint) WritePacket(r *stack.Route, gso *stack.GSO, params // Add the protocol's header to the packet and send it to the link // endpoint. hdr := pkt.NetworkHeader().Push(fakeNetHeaderLen) + pkt.NetworkProtocolNumber = fakeNetNumber hdr[dstAddrOffset] = r.RemoteAddress[0] hdr[srcAddrOffset] = r.LocalAddress[0] hdr[protocolNumberOffset] = byte(params.Protocol) if r.Loop&stack.PacketLoop != 0 { - f.HandlePacket(r, pkt) + pkt := pkt.Clone() + r.PopulatePacketInfo(pkt) + f.HandlePacket(pkt) } if r.Loop&stack.PacketOut == 0 { return nil @@ -254,6 +259,7 @@ func (*fakeNetworkProtocol) Parse(pkt *stack.PacketBuffer) (tcpip.TransportProto if !ok { return 0, false, false } + pkt.NetworkProtocolNumber = fakeNetNumber return tcpip.TransportProtocolNumber(hdr[protocolNumberOffset]), true, true } @@ -1334,6 +1340,106 @@ func TestPromiscuousMode(t *testing.T) { testFailingRecv(t, fakeNet, localAddrByte, ep, buf) } +// TestExternalSendWithHandleLocal tests that the stack creates a non-local +// route when spoofing or promiscuous mode are enabled. +// +// This test makes sure that packets are transmitted from the stack. +func TestExternalSendWithHandleLocal(t *testing.T) { + const ( + unspecifiedNICID = 0 + nicID = 1 + + localAddr = tcpip.Address("\x01") + dstAddr = tcpip.Address("\x03") + ) + + subnet, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + tests := []struct { + name string + configureStack func(*testing.T, *stack.Stack) + }{ + { + name: "Default", + configureStack: func(*testing.T, *stack.Stack) {}, + }, + { + name: "Spoofing", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetSpoofing(nicID, true); err != nil { + t.Fatalf("s.SetSpoofing(%d, true): %s", nicID, err) + } + }, + }, + { + name: "Promiscuous", + configureStack: func(t *testing.T, s *stack.Stack) { + if err := s.SetPromiscuousMode(nicID, true); err != nil { + t.Fatalf("s.SetPromiscuousMode(%d, true): %s", nicID, err) + } + }, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, handleLocal := range []bool{true, false} { + t.Run(fmt.Sprintf("HandleLocal=%t", handleLocal), func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{fakeNetFactory}, + HandleLocal: handleLocal, + }) + + ep := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + if err := s.AddAddress(nicID, fakeNetNumber, localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, fakeNetNumber, localAddr, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: subnet, NIC: nicID}}) + + test.configureStack(t, s) + + r, err := s.FindRoute(unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, false /* multicastLoop */) + if err != nil { + t.Fatalf("s.FindRoute(%d, %s, %s, %d, false): %s", unspecifiedNICID, localAddr, dstAddr, fakeNetNumber, err) + } + defer r.Release() + + if r.LocalAddress != localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, localAddr) + } + if r.RemoteAddress != dstAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, dstAddr) + } + + if n := ep.Drain(); n != 0 { + t.Fatalf("got ep.Drain() = %d, want = 0", n) + } + if err := r.WritePacket(nil /* gso */, stack.NetworkHeaderParams{ + Protocol: fakeTransNumber, + TTL: 123, + TOS: stack.DefaultTOS, + }, stack.NewPacketBuffer(stack.PacketBufferOptions{ + ReserveHeaderBytes: int(r.MaxHeaderLength()), + Data: buffer.NewView(10).ToVectorisedView(), + })); err != nil { + t.Fatalf("r.WritePacket(nil, _, _): %s", err) + } + if n := ep.Drain(); n != 1 { + t.Fatalf("got ep.Drain() = %d, want = 1", n) + } + }) + } + }) + } +} + func TestSpoofingWithAddress(t *testing.T) { localAddr := tcpip.Address("\x01") nonExistentLocalAddr := tcpip.Address("\x02") @@ -3346,7 +3452,7 @@ func TestOutgoingSubnetBroadcast(t *testing.T) { RemoteAddress: ipv4SubnetBcast, RemoteLinkAddress: header.EthernetBroadcastAddress, NetProto: header.IPv4ProtocolNumber, - Loop: stack.PacketOut, + Loop: stack.PacketOut | stack.PacketLoop, }, }, // Broadcast to a locally attached /31 subnet does not populate the @@ -3672,3 +3778,453 @@ func TestGetMainNICAddressWhenNICDisabled(t *testing.T) { t.Fatalf("got GetMainNICAddress(%d, %d) = %s, want = %s", nicID, fakeNetNumber, gotAddr, protocolAddress.AddressWithPrefix) } } + +// TestAddRoute tests Stack.AddRoute +func TestAddRoute(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + subnet1, err := tcpip.NewSubnet("\x00", "\x00") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet("\x01", "\x01") + if err != nil { + t.Fatal(err) + } + + expected := []tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + } + + // Initialize the route table with one route. + s.SetRouteTable([]tcpip.Route{expected[0]}) + + // Add another route. + s.AddRoute(expected[1]) + + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +// TestRemoveRoutes tests Stack.RemoveRoutes +func TestRemoveRoutes(t *testing.T) { + const nicID = 1 + + s := stack.New(stack.Options{}) + + addressToRemove := tcpip.Address("\x01") + subnet1, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet2, err := tcpip.NewSubnet(addressToRemove, "\x01") + if err != nil { + t.Fatal(err) + } + + subnet3, err := tcpip.NewSubnet("\x02", "\x02") + if err != nil { + t.Fatal(err) + } + + // Initialize the route table with three routes. + s.SetRouteTable([]tcpip.Route{ + {Destination: subnet1, Gateway: "\x00", NIC: 1}, + {Destination: subnet2, Gateway: "\x00", NIC: 1}, + {Destination: subnet3, Gateway: "\x00", NIC: 1}, + }) + + // Remove routes with the specific address. + s.RemoveRoutes(func(r tcpip.Route) bool { + return r.Destination.ID() == addressToRemove + }) + + expected := []tcpip.Route{{Destination: subnet3, Gateway: "\x00", NIC: 1}} + rt := s.GetRouteTable() + if got, want := len(rt), len(expected); got != want { + t.Fatalf("Unexpected route table length got = %d, want = %d", got, want) + } + for i, route := range rt { + if got, want := route, expected[i]; got != want { + t.Fatalf("Unexpected route got = %#v, want = %#v", got, want) + } + } +} + +func TestFindRouteWithForwarding(t *testing.T) { + const ( + nicID1 = 1 + nicID2 = 2 + + nic1Addr = tcpip.Address("\x01") + nic2Addr = tcpip.Address("\x02") + remoteAddr = tcpip.Address("\x03") + ) + + type netCfg struct { + proto tcpip.NetworkProtocolNumber + factory stack.NetworkProtocolFactory + nic1Addr tcpip.Address + nic2Addr tcpip.Address + remoteAddr tcpip.Address + } + + fakeNetCfg := netCfg{ + proto: fakeNetNumber, + factory: fakeNetFactory, + nic1Addr: nic1Addr, + nic2Addr: nic2Addr, + remoteAddr: remoteAddr, + } + + globalIPv6Addr1 := tcpip.Address(net.ParseIP("a::1").To16()) + globalIPv6Addr2 := tcpip.Address(net.ParseIP("a::2").To16()) + + ipv6LinkLocalNIC1WithGlobalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: llAddr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: globalIPv6Addr1, + } + ipv6GlobalNIC1WithLinkLocalRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: llAddr1, + remoteAddr: llAddr2, + } + ipv6GlobalNIC1WithLinkLocalMulticastRemote := netCfg{ + proto: ipv6.ProtocolNumber, + factory: ipv6.NewProtocol, + nic1Addr: globalIPv6Addr1, + nic2Addr: globalIPv6Addr2, + remoteAddr: "\xff\x02\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01", + } + + tests := []struct { + name string + + netCfg netCfg + forwardingEnabled bool + + addrNIC tcpip.NICID + localAddr tcpip.Address + + findRouteErr *tcpip.Error + dependentOnForwarding bool + }{ + { + name: "forwarding disabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC but route from different NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID1, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on specified NIC and route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: false, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr not on specified NIC but route from same NIC", + netCfg: fakeNetCfg, + forwardingEnabled: true, + addrNIC: nicID2, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on same NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: false, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and localAddr on different NIC as route", + netCfg: fakeNetCfg, + forwardingEnabled: true, + localAddr: fakeNetCfg.nic1Addr, + findRouteErr: nil, + dependentOnForwarding: true, + }, + { + name: "forwarding disabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and specified NIC only has link-local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + addrNIC: nicID1, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on different NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: false, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic1Addr, + findRouteErr: tcpip.ErrNoRoute, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with route on same NIC", + netCfg: ipv6LinkLocalNIC1WithGlobalRemote, + forwardingEnabled: true, + localAddr: ipv6LinkLocalNIC1WithGlobalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and link-local local addr with route on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on different NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic1Addr, + findRouteErr: tcpip.ErrNetworkUnreachable, + dependentOnForwarding: false, + }, + { + name: "forwarding disabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: false, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + { + name: "forwarding enabled and global local addr with link-local multicast remote on same NIC", + netCfg: ipv6GlobalNIC1WithLinkLocalMulticastRemote, + forwardingEnabled: true, + localAddr: ipv6GlobalNIC1WithLinkLocalMulticastRemote.nic2Addr, + findRouteErr: nil, + dependentOnForwarding: false, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{test.netCfg.factory}, + }) + + ep1 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID1, ep1); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID1, err) + } + + ep2 := channel.New(1, defaultMTU, "") + if err := s.CreateNIC(nicID2, ep2); err != nil { + t.Fatalf("CreateNIC(%d, _): %s:", nicID2, err) + } + + if err := s.AddAddress(nicID1, test.netCfg.proto, test.netCfg.nic1Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID1, test.netCfg.proto, test.netCfg.nic1Addr, err) + } + + if err := s.AddAddress(nicID2, test.netCfg.proto, test.netCfg.nic2Addr); err != nil { + t.Fatalf("AddAddress(%d, %d, %s): %s", nicID2, test.netCfg.proto, test.netCfg.nic2Addr, err) + } + + if err := s.SetForwarding(test.netCfg.proto, test.forwardingEnabled); err != nil { + t.Fatalf("SetForwarding(%d, %t): %s", test.netCfg.proto, test.forwardingEnabled, err) + } + + s.SetRouteTable([]tcpip.Route{{Destination: test.netCfg.remoteAddr.WithPrefix().Subnet(), NIC: nicID2}}) + + r, err := s.FindRoute(test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, false /* multicastLoop */) + if err != test.findRouteErr { + t.Fatalf("FindRoute(%d, %s, %s, %d, false) = %s, want = %s", test.addrNIC, test.localAddr, test.netCfg.remoteAddr, test.netCfg.proto, err, test.findRouteErr) + } + defer r.Release() + + if test.findRouteErr != nil { + return + } + + if r.LocalAddress != test.localAddr { + t.Errorf("got r.LocalAddress = %s, want = %s", r.LocalAddress, test.localAddr) + } + if r.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got r.RemoteAddress = %s, want = %s", r.RemoteAddress, test.netCfg.remoteAddr) + } + + if t.Failed() { + t.FailNow() + } + + // Sending a packet should always go through NIC2 since we only install a + // route to test.netCfg.remoteAddr through NIC2. + data := buffer.View([]byte{1, 2, 3, 4}) + if err := send(r, data); err != nil { + t.Fatalf("send(_, _): %s", err) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + pkt, ok := ep2.Read() + if !ok { + t.Fatal("packet not sent through ep2") + } + if pkt.Route.LocalAddress != test.localAddr { + t.Errorf("got pkt.Route.LocalAddress = %s, want = %s", pkt.Route.LocalAddress, test.localAddr) + } + if pkt.Route.RemoteAddress != test.netCfg.remoteAddr { + t.Errorf("got pkt.Route.RemoteAddress = %s, want = %s", pkt.Route.RemoteAddress, test.netCfg.remoteAddr) + } + + if !test.forwardingEnabled || !test.dependentOnForwarding { + return + } + + // Disabling forwarding when the route is dependent on forwarding being + // enabled should make the route invalid. + if err := s.SetForwarding(test.netCfg.proto, false); err != nil { + t.Fatalf("SetForwarding(%d, false): %s", test.netCfg.proto, err) + } + if err := send(r, data); err != tcpip.ErrInvalidEndpointState { + t.Fatalf("got send(_, _) = %s, want = %s", err, tcpip.ErrInvalidEndpointState) + } + if n := ep1.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep1", n) + } + if n := ep2.Drain(); n != 0 { + t.Errorf("got %d unexpected packets from ep2", n) + } + }) + } +} diff --git a/pkg/tcpip/stack/transport_demuxer.go b/pkg/tcpip/stack/transport_demuxer.go index 35e5b1a2e..f183ec6e4 100644 --- a/pkg/tcpip/stack/transport_demuxer.go +++ b/pkg/tcpip/stack/transport_demuxer.go @@ -152,10 +152,10 @@ func (epsByNIC *endpointsByNIC) transportEndpoints() []TransportEndpoint { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (epsByNIC *endpointsByNIC) handlePacket(id TransportEndpointID, pkt *PacketBuffer) { epsByNIC.mu.RLock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[pkt.NICID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -165,20 +165,20 @@ func (epsByNIC *endpointsByNIC) handlePacket(r *Route, id TransportEndpointID, p // If this is a broadcast or multicast datagram, deliver the datagram to all // endpoints bound to the right device. - if isInboundMulticastOrBroadcast(r) { - mpep.handlePacketAll(r, id, pkt) + if isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { + mpep.handlePacketAll(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. return } // multiPortEndpoints are guaranteed to have at least one element. transEP := selectEndpoint(id, mpep, epsByNIC.seed) if queuedProtocol, mustQueue := mpep.demux.queuedProtocols[protocolIDs{mpep.netProto, mpep.transProto}]; mustQueue { - queuedProtocol.QueuePacket(r, transEP, id, pkt) + queuedProtocol.QueuePacket(transEP, id, pkt) epsByNIC.mu.RUnlock() return } - transEP.HandlePacket(r, id, pkt) + transEP.HandlePacket(id, pkt) epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. } @@ -253,6 +253,8 @@ func (epsByNIC *endpointsByNIC) unregisterEndpoint(bindToDevice tcpip.NICID, t T // based on endpoints IDs. It should only be instantiated via // newTransportDemuxer. type transportDemuxer struct { + stack *Stack + // protocol is immutable. protocol map[protocolIDs]*transportEndpoints queuedProtocols map[protocolIDs]queuedTransportProtocol @@ -262,11 +264,12 @@ type transportDemuxer struct { // the dispatcher to delivery packets to the QueuePacket method instead of // calling HandlePacket directly on the endpoint. type queuedTransportProtocol interface { - QueuePacket(r *Route, ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) + QueuePacket(ep TransportEndpoint, id TransportEndpointID, pkt *PacketBuffer) } func newTransportDemuxer(stack *Stack) *transportDemuxer { d := &transportDemuxer{ + stack: stack, protocol: make(map[protocolIDs]*transportEndpoints), queuedProtocols: make(map[protocolIDs]queuedTransportProtocol), } @@ -377,22 +380,22 @@ func selectEndpoint(id TransportEndpointID, mpep *multiPortEndpoint, seed uint32 return mpep.endpoints[idx] } -func (ep *multiPortEndpoint) handlePacketAll(r *Route, id TransportEndpointID, pkt *PacketBuffer) { +func (ep *multiPortEndpoint) handlePacketAll(id TransportEndpointID, pkt *PacketBuffer) { ep.mu.RLock() queuedProtocol, mustQueue := ep.demux.queuedProtocols[protocolIDs{ep.netProto, ep.transProto}] // HandlePacket takes ownership of pkt, so each endpoint needs // its own copy except for the final one. for _, endpoint := range ep.endpoints[:len(ep.endpoints)-1] { if mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt.Clone()) + queuedProtocol.QueuePacket(endpoint, id, pkt.Clone()) } else { - endpoint.HandlePacket(r, id, pkt.Clone()) + endpoint.HandlePacket(id, pkt.Clone()) } } if endpoint := ep.endpoints[len(ep.endpoints)-1]; mustQueue { - queuedProtocol.QueuePacket(r, endpoint, id, pkt) + queuedProtocol.QueuePacket(endpoint, id, pkt) } else { - endpoint.HandlePacket(r, id, pkt) + endpoint.HandlePacket(id, pkt) } ep.mu.RUnlock() // Don't use defer for performance reasons. } @@ -518,29 +521,29 @@ func (d *transportDemuxer) unregisterEndpoint(netProtos []tcpip.NetworkProtocolN // deliverPacket attempts to find one or more matching transport endpoints, and // then, if matches are found, delivers the packet to them. Returns true if // the packet no longer needs to be handled. -func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer, id TransportEndpointID) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } // If the packet is a UDP broadcast or multicast, then find all matching // transport endpoints. - if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(r) { + if protocol == header.UDPProtocolNumber && isInboundMulticastOrBroadcast(pkt, id.LocalAddress) { eps.mu.RLock() destEPs := eps.findAllEndpointsLocked(id) eps.mu.RUnlock() // Fail if we didn't find at least one matching transport endpoint. if len(destEPs) == 0 { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() return false } // handlePacket takes ownership of pkt, so each endpoint needs its own // copy except for the final one. for _, ep := range destEPs[:len(destEPs)-1] { - ep.handlePacket(r, id, pkt.Clone()) + ep.handlePacket(id, pkt.Clone()) } - destEPs[len(destEPs)-1].handlePacket(r, id, pkt) + destEPs[len(destEPs)-1].handlePacket(id, pkt) return true } @@ -548,10 +551,10 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto // destination address, then do nothing further and instruct the caller to do // the same. The network layer handles address validation for specified source // addresses. - if protocol == header.TCPProtocolNumber && (!isSpecified(r.LocalAddress) || !isSpecified(r.RemoteAddress) || isInboundMulticastOrBroadcast(r)) { + if protocol == header.TCPProtocolNumber && (!isSpecified(id.LocalAddress) || !isSpecified(id.RemoteAddress) || isInboundMulticastOrBroadcast(pkt, id.LocalAddress)) { // TCP can only be used to communicate between a single source and a - // single destination; the addresses must be unicast. - r.Stats().TCP.InvalidSegmentsReceived.Increment() + // single destination; the addresses must be unicast.e + d.stack.stats.TCP.InvalidSegmentsReceived.Increment() return true } @@ -560,18 +563,18 @@ func (d *transportDemuxer) deliverPacket(r *Route, protocol tcpip.TransportProto eps.mu.RUnlock() if ep == nil { if protocol == header.UDPProtocolNumber { - r.Stats().UDP.UnknownPortErrors.Increment() + d.stack.stats.UDP.UnknownPortErrors.Increment() } return false } - ep.handlePacket(r, id, pkt) + ep.handlePacket(id, pkt) return true } // deliverRawPacket attempts to deliver the given packet and returns whether it // was delivered successfully. -func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { - eps, ok := d.protocol[protocolIDs{r.NetProto, protocol}] +func (d *transportDemuxer) deliverRawPacket(protocol tcpip.TransportProtocolNumber, pkt *PacketBuffer) bool { + eps, ok := d.protocol[protocolIDs{pkt.NetworkProtocolNumber, protocol}] if !ok { return false } @@ -584,7 +587,7 @@ func (d *transportDemuxer) deliverRawPacket(r *Route, protocol tcpip.TransportPr for _, rawEP := range eps.rawEndpoints { // Each endpoint gets its own copy of the packet for the sake // of save/restore. - rawEP.HandlePacket(r, pkt) + rawEP.HandlePacket(pkt.Clone()) foundRaw = true } eps.mu.RUnlock() @@ -612,7 +615,7 @@ func (d *transportDemuxer) deliverControlPacket(n *NIC, net tcpip.NetworkProtoco } // findTransportEndpoint find a single endpoint that most closely matches the provided id. -func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, r *Route) TransportEndpoint { +func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolNumber, transProto tcpip.TransportProtocolNumber, id TransportEndpointID, nicID tcpip.NICID) TransportEndpoint { eps, ok := d.protocol[protocolIDs{netProto, transProto}] if !ok { return nil @@ -628,7 +631,7 @@ func (d *transportDemuxer) findTransportEndpoint(netProto tcpip.NetworkProtocolN epsByNIC.mu.RLock() eps.mu.RUnlock() - mpep, ok := epsByNIC.endpoints[r.nic.ID()] + mpep, ok := epsByNIC.endpoints[nicID] if !ok { if mpep, ok = epsByNIC.endpoints[0]; !ok { epsByNIC.mu.RUnlock() // Don't use defer for performance reasons. @@ -679,8 +682,8 @@ func (d *transportDemuxer) unregisterRawEndpoint(netProto tcpip.NetworkProtocolN eps.mu.Unlock() } -func isInboundMulticastOrBroadcast(r *Route) bool { - return r.IsInboundBroadcast() || header.IsV4MulticastAddress(r.LocalAddress) || header.IsV6MulticastAddress(r.LocalAddress) +func isInboundMulticastOrBroadcast(pkt *PacketBuffer, localAddr tcpip.Address) bool { + return pkt.NetworkPacketInfo.LocalAddressBroadcast || header.IsV4MulticastAddress(localAddr) || header.IsV6MulticastAddress(localAddr) } func isSpecified(addr tcpip.Address) bool { diff --git a/pkg/tcpip/stack/transport_test.go b/pkg/tcpip/stack/transport_test.go index 62ab6d92f..c457b67a2 100644 --- a/pkg/tcpip/stack/transport_test.go +++ b/pkg/tcpip/stack/transport_test.go @@ -28,7 +28,7 @@ import ( const ( fakeTransNumber tcpip.TransportProtocolNumber = 1 - fakeTransHeaderLen = 3 + fakeTransHeaderLen int = 3 ) // fakeTransportEndpoint is a transport-layer protocol endpoint. It counts @@ -213,20 +213,29 @@ func (*fakeTransportEndpoint) GetRemoteAddress() (tcpip.FullAddress, *tcpip.Erro return tcpip.FullAddress{}, nil } -func (f *fakeTransportEndpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, _ *stack.PacketBuffer) { +func (f *fakeTransportEndpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Increment the number of received packets. f.proto.packetCount++ - if f.acceptQueue != nil { - f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ - TransportEndpointInfo: stack.TransportEndpointInfo{ - ID: f.ID, - NetProto: f.NetProto, - }, - proto: f.proto, - peerAddr: r.RemoteAddress, - route: r.Clone(), - }) + if f.acceptQueue == nil { + return } + + netHdr := pkt.NetworkHeader().View() + route, err := f.proto.stack.FindRoute(pkt.NICID, tcpip.Address(netHdr[dstAddrOffset]), tcpip.Address(netHdr[srcAddrOffset]), pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return + } + route.ResolveWith(pkt.SourceLinkAddress()) + + f.acceptQueue = append(f.acceptQueue, fakeTransportEndpoint{ + TransportEndpointInfo: stack.TransportEndpointInfo{ + ID: f.ID, + NetProto: f.NetProto, + }, + proto: f.proto, + peerAddr: route.RemoteAddress, + route: route, + }) } func (f *fakeTransportEndpoint) HandleControlPacket(stack.TransportEndpointID, stack.ControlType, uint32, *stack.PacketBuffer) { @@ -288,7 +297,7 @@ func (*fakeTransportProtocol) ParsePorts(buffer.View) (src, dst uint16, err *tcp return 0, 0, nil } -func (*fakeTransportProtocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*fakeTransportProtocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/tcpip.go b/pkg/tcpip/tcpip.go index d77848d61..3ab2b7654 100644 --- a/pkg/tcpip/tcpip.go +++ b/pkg/tcpip/tcpip.go @@ -356,10 +356,9 @@ func (s *Subnet) IsBroadcast(address Address) bool { return s.Prefix() <= 30 && s.Broadcast() == address } -// Equal returns true if s equals o. -// -// Needed to use cmp.Equal on Subnet as its fields are unexported. +// Equal returns true if this Subnet is equal to the given Subnet. func (s Subnet) Equal(o Subnet) bool { + // If this changes, update Route.Equal accordingly. return s == o } @@ -763,6 +762,10 @@ const ( // endpoint that all packets being written have an IP header and the // endpoint should not attach an IP header. IPHdrIncludedOption + + // AcceptConnOption is used by GetSockOptBool to indicate if the + // socket is a listening socket. + AcceptConnOption ) // SockOptInt represents socket options which values have the int type. @@ -1256,6 +1259,12 @@ func (r Route) String() string { return out.String() } +// Equal returns true if the given Route is equal to this Route. +func (r Route) Equal(to Route) bool { + // NOTE: This relies on the fact that r.Destination == to.Destination + return r == to +} + // TransportProtocolNumber is the number of a transport protocol. type TransportProtocolNumber uint32 @@ -1496,6 +1505,15 @@ type IPStats struct { // IPTablesOutputDropped is the total number of IP packets dropped in // the Output chain. IPTablesOutputDropped *StatCounter + + // OptionTSReceived is the number of Timestamp options seen. + OptionTSReceived *StatCounter + + // OptionRRReceived is the number of Record Route options seen. + OptionRRReceived *StatCounter + + // OptionUnknownReceived is the number of unknown IP options seen. + OptionUnknownReceived *StatCounter } // TCPStats collects TCP-specific stats. diff --git a/pkg/tcpip/tests/integration/BUILD b/pkg/tcpip/tests/integration/BUILD index 34aab32d0..9b0f3b675 100644 --- a/pkg/tcpip/tests/integration/BUILD +++ b/pkg/tcpip/tests/integration/BUILD @@ -10,6 +10,7 @@ go_test( "link_resolution_test.go", "loopback_test.go", "multicast_broadcast_test.go", + "route_test.go", ], deps = [ "//pkg/tcpip", diff --git a/pkg/tcpip/tests/integration/forward_test.go b/pkg/tcpip/tests/integration/forward_test.go index 0dcef7b04..bf7594268 100644 --- a/pkg/tcpip/tests/integration/forward_test.go +++ b/pkg/tcpip/tests/integration/forward_test.go @@ -33,11 +33,6 @@ import ( func TestForwarding(t *testing.T) { const ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - routerNIC1LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") - routerNIC2LinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") - host1NICID = 1 routerNICID1 = 2 routerNICID2 = 3 @@ -166,6 +161,38 @@ func TestForwarding(t *testing.T) { } }, }, + { + name: "IPv4 host2 server with routerNIC1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, host2Stack, udp.ProtocolNumber, ipv4.ProtocolNumber) + ep2, ep2WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv4.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: host2IPv4Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: routerNIC1IPv4Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, + { + name: "IPv6 routerNIC2 server with host1 client", + epAndAddrs: func(t *testing.T, host1Stack, routerStack, host2Stack *stack.Stack) endpointAndAddresses { + ep1, ep1WECH := newEP(t, routerStack, udp.ProtocolNumber, ipv6.ProtocolNumber) + ep2, ep2WECH := newEP(t, host1Stack, udp.ProtocolNumber, ipv6.ProtocolNumber) + return endpointAndAddresses{ + serverEP: ep1, + serverAddr: routerNIC2IPv6Addr.AddressWithPrefix.Address, + serverReadableCH: ep1WECH, + + clientEP: ep2, + clientAddr: host1IPv6Addr.AddressWithPrefix.Address, + clientReadableCH: ep2WECH, + } + }, + }, } for _, test := range tests { @@ -179,8 +206,8 @@ func TestForwarding(t *testing.T) { routerStack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, routerNIC1 := pipe.New(host1NICLinkAddr, routerNIC1LinkAddr) - routerNIC2, host2NIC := pipe.New(routerNIC2LinkAddr, host2NICLinkAddr) + host1NIC, routerNIC1 := pipe.New(linkAddr1, linkAddr2) + routerNIC2, host2NIC := pipe.New(linkAddr3, linkAddr4) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -321,12 +348,8 @@ func TestForwarding(t *testing.T) { if err == tcpip.ErrNoLinkAddress { // Wait for link resolution to complete. <-ch - n, _, err = ep.Write(dataPayload, wOpts) - } else if err != nil { - t.Fatalf("ep.Write(_, _): %s", err) } - if err != nil { t.Fatalf("ep.Write(_, _): %s", err) } @@ -343,7 +366,6 @@ func TestForwarding(t *testing.T) { // Wait for the endpoint to be readable. <-ch - var addr tcpip.FullAddress v, _, err := ep.Read(&addr) if err != nil { diff --git a/pkg/tcpip/tests/integration/link_resolution_test.go b/pkg/tcpip/tests/integration/link_resolution_test.go index 6ddcda70c..fe7c1bb3d 100644 --- a/pkg/tcpip/tests/integration/link_resolution_test.go +++ b/pkg/tcpip/tests/integration/link_resolution_test.go @@ -32,32 +32,36 @@ import ( "gvisor.dev/gvisor/pkg/waiter" ) -var ( - host1NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") - host2NICLinkAddr = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +const ( + linkAddr1 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x06") + linkAddr2 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x07") + linkAddr3 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x08") + linkAddr4 = tcpip.LinkAddress("\x02\x03\x03\x04\x05\x09") +) - host1IPv4Addr = tcpip.ProtocolAddress{ +var ( + ipv4Addr1 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.1").To4()), PrefixLen: 24, }, } - host2IPv4Addr = tcpip.ProtocolAddress{ + ipv4Addr2 = tcpip.ProtocolAddress{ Protocol: ipv4.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("192.168.0.2").To4()), PrefixLen: 8, }, } - host1IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr1 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::1").To16()), PrefixLen: 64, }, } - host2IPv6Addr = tcpip.ProtocolAddress{ + ipv6Addr2 = tcpip.ProtocolAddress{ Protocol: ipv6.ProtocolNumber, AddressWithPrefix: tcpip.AddressWithPrefix{ Address: tcpip.Address(net.ParseIP("a::2").To16()), @@ -89,7 +93,7 @@ func TestPing(t *testing.T) { name: "IPv4 Ping", transProto: icmp.ProtocolNumber4, netProto: ipv4.ProtocolNumber, - remoteAddr: host2IPv4Addr.AddressWithPrefix.Address, + remoteAddr: ipv4Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) @@ -104,7 +108,7 @@ func TestPing(t *testing.T) { name: "IPv6 Ping", transProto: icmp.ProtocolNumber6, netProto: ipv6.ProtocolNumber, - remoteAddr: host2IPv6Addr.AddressWithPrefix.Address, + remoteAddr: ipv6Addr2.AddressWithPrefix.Address, icmpBuf: func(t *testing.T) buffer.View { data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) @@ -127,7 +131,7 @@ func TestPing(t *testing.T) { host1Stack := stack.New(stackOpts) host2Stack := stack.New(stackOpts) - host1NIC, host2NIC := pipe.New(host1NICLinkAddr, host2NICLinkAddr) + host1NIC, host2NIC := pipe.New(linkAddr1, linkAddr2) if err := host1Stack.CreateNIC(host1NICID, ethernet.New(host1NIC)); err != nil { t.Fatalf("host1Stack.CreateNIC(%d, _): %s", host1NICID, err) @@ -143,36 +147,36 @@ func TestPing(t *testing.T) { t.Fatalf("host2Stack.AddAddress(%d, %d, %s): %s", host2NICID, arp.ProtocolNumber, arp.ProtocolAddress, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv4Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv4Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv4Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv4Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv4Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv4Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv4Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv4Addr2, err) } - if err := host1Stack.AddProtocolAddress(host1NICID, host1IPv6Addr); err != nil { - t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, host1IPv6Addr, err) + if err := host1Stack.AddProtocolAddress(host1NICID, ipv6Addr1); err != nil { + t.Fatalf("host1Stack.AddProtocolAddress(%d, %#v): %s", host1NICID, ipv6Addr1, err) } - if err := host2Stack.AddProtocolAddress(host2NICID, host2IPv6Addr); err != nil { - t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, host2IPv6Addr, err) + if err := host2Stack.AddProtocolAddress(host2NICID, ipv6Addr2); err != nil { + t.Fatalf("host2Stack.AddProtocolAddress(%d, %#v): %s", host2NICID, ipv6Addr2, err) } host1Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host1IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, tcpip.Route{ - Destination: host1IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr1.AddressWithPrefix.Subnet(), NIC: host1NICID, }, }) host2Stack.SetRouteTable([]tcpip.Route{ tcpip.Route{ - Destination: host2IPv4Addr.AddressWithPrefix.Subnet(), + Destination: ipv4Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, tcpip.Route{ - Destination: host2IPv6Addr.AddressWithPrefix.Subnet(), + Destination: ipv6Addr2.AddressWithPrefix.Subnet(), NIC: host2NICID, }, }) diff --git a/pkg/tcpip/tests/integration/loopback_test.go b/pkg/tcpip/tests/integration/loopback_test.go index e8caf09ba..421da1add 100644 --- a/pkg/tcpip/tests/integration/loopback_test.go +++ b/pkg/tcpip/tests/integration/loopback_test.go @@ -204,7 +204,7 @@ func TestLoopbackAcceptAllInSubnet(t *testing.T) { }, }) - wq := waiter.Queue{} + var wq waiter.Queue rep, err := s.NewEndpoint(udp.ProtocolNumber, test.addAddress.Protocol, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, test.addAddress.Protocol, err) diff --git a/pkg/tcpip/tests/integration/multicast_broadcast_test.go b/pkg/tcpip/tests/integration/multicast_broadcast_test.go index f1028823b..cdf0459e3 100644 --- a/pkg/tcpip/tests/integration/multicast_broadcast_test.go +++ b/pkg/tcpip/tests/integration/multicast_broadcast_test.go @@ -409,7 +409,7 @@ func TestIncomingMulticastAndBroadcast(t *testing.T) { t.Fatalf("got unexpected address length = %d bytes", l) } - wq := waiter.Queue{} + var wq waiter.Queue ep, err := s.NewEndpoint(udp.ProtocolNumber, netproto, &wq) if err != nil { t.Fatalf("NewEndpoint(%d, %d, _): %s", udp.ProtocolNumber, netproto, err) @@ -447,8 +447,6 @@ func TestReuseAddrAndBroadcast(t *testing.T) { loopbackBroadcast = tcpip.Address("\x7f\xff\xff\xff") ) - data := tcpip.SlicePayload([]byte{1, 2, 3, 4}) - tests := []struct { name string broadcastAddr tcpip.Address @@ -492,16 +490,22 @@ func TestReuseAddrAndBroadcast(t *testing.T) { }, }) + type endpointAndWaiter struct { + ep tcpip.Endpoint + ch chan struct{} + } + var eps []endpointAndWaiter // We create endpoints that bind to both the wildcard address and the // broadcast address to make sure both of these types of "broadcast // interested" endpoints receive broadcast packets. - wq := waiter.Queue{} - var eps []tcpip.Endpoint for _, bindWildcard := range []bool{false, true} { // Create multiple endpoints for each type of "broadcast interested" // endpoint so we can test that all endpoints receive the broadcast // packet. for i := 0; i < 2; i++ { + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) ep, err := s.NewEndpoint(udp.ProtocolNumber, ipv4.ProtocolNumber, &wq) if err != nil { t.Fatalf("(eps[%d]) NewEndpoint(%d, %d, _): %s", len(eps), udp.ProtocolNumber, ipv4.ProtocolNumber, err) @@ -528,7 +532,7 @@ func TestReuseAddrAndBroadcast(t *testing.T) { } } - eps = append(eps, ep) + eps = append(eps, endpointAndWaiter{ep: ep, ch: ch}) } } @@ -539,14 +543,18 @@ func TestReuseAddrAndBroadcast(t *testing.T) { Port: localPort, }, } - if n, _, err := wep.Write(data, writeOpts); err != nil { + data := tcpip.SlicePayload([]byte{byte(i), 2, 3, 4}) + if n, _, err := wep.ep.Write(data, writeOpts); err != nil { t.Fatalf("eps[%d].Write(_, _): %s", i, err) } else if want := int64(len(data)); n != want { t.Fatalf("got eps[%d].Write(_, _) = (%d, nil, nil), want = (%d, nil, nil)", i, n, want) } for j, rep := range eps { - if gotPayload, _, err := rep.Read(nil); err != nil { + // Wait for the endpoint to become readable. + <-rep.ch + + if gotPayload, _, err := rep.ep.Read(nil); err != nil { t.Errorf("(eps[%d] write) eps[%d].Read(nil): %s", i, j, err) } else if diff := cmp.Diff(buffer.View(data), gotPayload); diff != "" { t.Errorf("(eps[%d] write) got UDP payload from eps[%d] mismatch (-want +got):\n%s", i, j, diff) diff --git a/pkg/tcpip/tests/integration/route_test.go b/pkg/tcpip/tests/integration/route_test.go new file mode 100644 index 000000000..02fc47015 --- /dev/null +++ b/pkg/tcpip/tests/integration/route_test.go @@ -0,0 +1,388 @@ +// Copyright 2020 The gVisor Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package integration_test + +import ( + "testing" + + "github.com/google/go-cmp/cmp" + "gvisor.dev/gvisor/pkg/tcpip" + "gvisor.dev/gvisor/pkg/tcpip/buffer" + "gvisor.dev/gvisor/pkg/tcpip/header" + "gvisor.dev/gvisor/pkg/tcpip/link/channel" + "gvisor.dev/gvisor/pkg/tcpip/link/loopback" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv4" + "gvisor.dev/gvisor/pkg/tcpip/network/ipv6" + "gvisor.dev/gvisor/pkg/tcpip/stack" + "gvisor.dev/gvisor/pkg/tcpip/transport/icmp" + "gvisor.dev/gvisor/pkg/tcpip/transport/udp" + "gvisor.dev/gvisor/pkg/waiter" +) + +// TestLocalPing tests pinging a remote that is local the stack. +// +// This tests that a local route is created and packets do not leave the stack. +func TestLocalPing(t *testing.T) { + const ( + nicID = 1 + ipv4Loopback = tcpip.Address("\x7f\x00\x00\x01") + + // icmpDataOffset is the offset to the data in both ICMPv4 and ICMPv6 echo + // request/reply packets. + icmpDataOffset = 8 + ) + + channelEP := func() stack.LinkEndpoint { return channel.New(1, header.IPv6MinimumMTU, "") } + channelEPCheck := func(t *testing.T, e stack.LinkEndpoint) { + channelEP := e.(*channel.Endpoint) + if n := channelEP.Drain(); n != 0 { + t.Fatalf("got channelEP.Drain() = %d, want = 0", n) + } + } + + ipv4ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 8} + hdr := header.ICMPv4(make([]byte, header.ICMPv4MinimumSize+len(data))) + hdr.SetType(header.ICMPv4Echo) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + ipv6ICMPBuf := func(t *testing.T) buffer.View { + data := [8]byte{1, 2, 3, 4, 5, 6, 7, 9} + hdr := header.ICMPv6(make([]byte, header.ICMPv6MinimumSize+len(data))) + hdr.SetType(header.ICMPv6EchoRequest) + if n := copy(hdr.Payload(), data[:]); n != len(data) { + t.Fatalf("copied %d bytes but expected to copy %d bytes", n, len(data)) + } + return buffer.View(hdr) + } + + tests := []struct { + name string + transProto tcpip.TransportProtocolNumber + netProto tcpip.NetworkProtocolNumber + linkEndpoint func() stack.LinkEndpoint + localAddr tcpip.Address + icmpBuf func(*testing.T) buffer.View + expectedConnectErr *tcpip.Error + checkLinkEndpoint func(t *testing.T, e stack.LinkEndpoint) + }{ + { + name: "IPv4 loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: ipv4Loopback, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + localAddr: header.IPv6Loopback, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv4Addr.Address, + icmpBuf: ipv4ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + localAddr: ipv6Addr.Address, + icmpBuf: ipv6ICMPBuf, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv4 loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv6 loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: loopback.New, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: func(*testing.T, stack.LinkEndpoint) {}, + }, + { + name: "IPv4 non-loopback without local address", + transProto: icmp.ProtocolNumber4, + netProto: ipv4.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv4ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + { + name: "IPv6 non-loopback without local address", + transProto: icmp.ProtocolNumber6, + netProto: ipv6.ProtocolNumber, + linkEndpoint: channelEP, + icmpBuf: ipv6ICMPBuf, + expectedConnectErr: tcpip.ErrNoRoute, + checkLinkEndpoint: channelEPCheck, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + s := stack.New(stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{icmp.NewProtocol4, icmp.NewProtocol6}, + HandleLocal: true, + }) + e := test.linkEndpoint() + if err := s.CreateNIC(nicID, e); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if len(test.localAddr) != 0 { + if err := s.AddAddress(nicID, test.netProto, test.localAddr); err != nil { + t.Fatalf("s.AddAddress(%d, %d, %s): %s", nicID, test.netProto, test.localAddr, err) + } + } + + var wq waiter.Queue + we, ch := waiter.NewChannelEntry(nil) + wq.EventRegister(&we, waiter.EventIn) + ep, err := s.NewEndpoint(test.transProto, test.netProto, &wq) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d, _): %s", test.transProto, test.netProto, err) + } + defer ep.Close() + + connAddr := tcpip.FullAddress{Addr: test.localAddr} + if err := ep.Connect(connAddr); err != test.expectedConnectErr { + t.Fatalf("got ep.Connect(%#v) = %s, want = %s", connAddr, err, test.expectedConnectErr) + } + + if test.expectedConnectErr != nil { + return + } + + payload := tcpip.SlicePayload(test.icmpBuf(t)) + var wOpts tcpip.WriteOptions + if n, _, err := ep.Write(payload, wOpts); err != nil { + t.Fatalf("ep.Write(%#v, %#v): %s", payload, wOpts, err) + } else if n != int64(len(payload)) { + t.Fatalf("got ep.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", payload, wOpts, n, len(payload)) + } + + // Wait for the endpoint to become readable. + <-ch + + var addr tcpip.FullAddress + v, _, err := ep.Read(&addr) + if err != nil { + t.Fatalf("ep.Read(_): %s", err) + } + if diff := cmp.Diff(v[icmpDataOffset:], buffer.View(payload[icmpDataOffset:])); diff != "" { + t.Errorf("received data mismatch (-want +got):\n%s", diff) + } + if addr.Addr != test.localAddr { + t.Errorf("got addr.Addr = %s, want = %s", addr.Addr, test.localAddr) + } + + test.checkLinkEndpoint(t, e) + }) + } +} + +// TestLocalUDP tests sending UDP packets between two endpoints that are local +// to the stack. +// +// This tests that that packets never leave the stack and the addresses +// used when sending a packet. +func TestLocalUDP(t *testing.T) { + const ( + nicID = 1 + ) + + tests := []struct { + name string + canBePrimaryAddr tcpip.ProtocolAddress + firstPrimaryAddr tcpip.ProtocolAddress + }{ + { + name: "IPv4", + canBePrimaryAddr: ipv4Addr1, + firstPrimaryAddr: ipv4Addr2, + }, + { + name: "IPv6", + canBePrimaryAddr: ipv6Addr1, + firstPrimaryAddr: ipv6Addr2, + }, + } + + subTests := []struct { + name string + addAddress bool + expectedWriteErr *tcpip.Error + }{ + { + name: "Unassigned local address", + addAddress: false, + expectedWriteErr: tcpip.ErrNoRoute, + }, + { + name: "Assigned local address", + addAddress: true, + expectedWriteErr: nil, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + for _, subTest := range subTests { + t.Run(subTest.name, func(t *testing.T) { + stackOpts := stack.Options{ + NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, + TransportProtocols: []stack.TransportProtocolFactory{udp.NewProtocol}, + HandleLocal: true, + } + + s := stack.New(stackOpts) + ep := channel.New(1, header.IPv6MinimumMTU, "") + + if err := s.CreateNIC(nicID, ep); err != nil { + t.Fatalf("s.CreateNIC(%d, _): %s", nicID, err) + } + + if subTest.addAddress { + if err := s.AddProtocolAddressWithOptions(nicID, test.canBePrimaryAddr, stack.CanBePrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.canBePrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + if err := s.AddProtocolAddressWithOptions(nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint); err != nil { + t.Fatalf("s.AddProtocolAddressWithOptions(%d, %#v, %d): %s", nicID, test.firstPrimaryAddr, stack.FirstPrimaryEndpoint, err) + } + } + + var serverWQ waiter.Queue + serverWE, serverCH := waiter.NewChannelEntry(nil) + serverWQ.EventRegister(&serverWE, waiter.EventIn) + server, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &serverWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer server.Close() + + bindAddr := tcpip.FullAddress{Port: 80} + if err := server.Bind(bindAddr); err != nil { + t.Fatalf("server.Bind(%#v): %s", bindAddr, err) + } + + var clientWQ waiter.Queue + clientWE, clientCH := waiter.NewChannelEntry(nil) + clientWQ.EventRegister(&clientWE, waiter.EventIn) + client, err := s.NewEndpoint(udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, &clientWQ) + if err != nil { + t.Fatalf("s.NewEndpoint(%d, %d): %s", udp.ProtocolNumber, test.firstPrimaryAddr.Protocol, err) + } + defer client.Close() + + serverAddr := tcpip.FullAddress{ + Addr: test.canBePrimaryAddr.AddressWithPrefix.Address, + Port: 80, + } + + clientPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &serverAddr, + } + if n, _, err := client.Write(clientPayload, wOpts); err != subTest.expectedWriteErr { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, %s_), want = (_, _, %s)", clientPayload, wOpts, n, err, subTest.expectedWriteErr) + } else if subTest.expectedWriteErr != nil { + // Nothing else to test if we expected not to be able to send the + // UDP packet. + return + } else if n != int64(len(clientPayload)) { + t.Fatalf("got client.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", clientPayload, wOpts, n, len(clientPayload)) + } + } + + // Wait for the server endpoint to become readable. + <-serverCH + + var clientAddr tcpip.FullAddress + if v, _, err := server.Read(&clientAddr); err != nil { + t.Fatalf("server.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(clientPayload), v); diff != "" { + t.Errorf("server read clientPayload mismatch (-want +got):\n%s", diff) + } + if clientAddr.Addr != test.canBePrimaryAddr.AddressWithPrefix.Address { + t.Errorf("got clientAddr.Addr = %s, want = %s", clientAddr.Addr, test.canBePrimaryAddr.AddressWithPrefix.Address) + } + if t.Failed() { + t.FailNow() + } + } + + serverPayload := tcpip.SlicePayload([]byte{1, 2, 3, 4}) + { + wOpts := tcpip.WriteOptions{ + To: &clientAddr, + } + if n, _, err := server.Write(serverPayload, wOpts); err != nil { + t.Fatalf("server.Write(%#v, %#v): %s", serverPayload, wOpts, err) + } else if n != int64(len(serverPayload)) { + t.Fatalf("got server.Write(%#v, %#v) = (%d, _, nil), want = (%d, _, nil)", serverPayload, wOpts, n, len(serverPayload)) + } + } + + // Wait for the client endpoint to become readable. + <-clientCH + + var gotServerAddr tcpip.FullAddress + if v, _, err := client.Read(&gotServerAddr); err != nil { + t.Fatalf("client.Read(_): %s", err) + } else { + if diff := cmp.Diff(buffer.View(serverPayload), v); diff != "" { + t.Errorf("client read serverPayload mismatch (-want +got):\n%s", diff) + } + if gotServerAddr.Addr != serverAddr.Addr { + t.Errorf("got gotServerAddr.Addr = %s, want = %s", gotServerAddr.Addr, serverAddr.Addr) + } + if t.Failed() { + t.FailNow() + } + } + }) + } + }) + } +} diff --git a/pkg/tcpip/transport/icmp/endpoint.go b/pkg/tcpip/transport/icmp/endpoint.go index 41eb0ca44..763cd8f84 100644 --- a/pkg/tcpip/transport/icmp/endpoint.go +++ b/pkg/tcpip/transport/icmp/endpoint.go @@ -378,7 +378,7 @@ func (e *endpoint) SetSockOptInt(opt tcpip.SockOptInt, v int) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil default: @@ -755,7 +755,7 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Only accept echo replies. switch e.NetProto { case header.IPv4ProtocolNumber: @@ -800,7 +800,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &icmpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, }, } diff --git a/pkg/tcpip/transport/icmp/protocol.go b/pkg/tcpip/transport/icmp/protocol.go index 87d510f96..3820e5dc7 100644 --- a/pkg/tcpip/transport/icmp/protocol.go +++ b/pkg/tcpip/transport/icmp/protocol.go @@ -101,7 +101,7 @@ func (p *protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) // HandleUnknownDestinationPacket handles packets targeted at this protocol but // that don't match any existing endpoint. -func (*protocol) HandleUnknownDestinationPacket(*stack.Route, stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (*protocol) HandleUnknownDestinationPacket(stack.TransportEndpointID, *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { return stack.UnknownDestinationPacketHandled } diff --git a/pkg/tcpip/transport/packet/endpoint.go b/pkg/tcpip/transport/packet/endpoint.go index 072601d2d..31831a6d8 100644 --- a/pkg/tcpip/transport/packet/endpoint.go +++ b/pkg/tcpip/transport/packet/endpoint.go @@ -389,7 +389,12 @@ func (ep *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (*endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { - return false, tcpip.ErrNotSupported + switch opt { + case tcpip.AcceptConnOption: + return false, nil + default: + return false, tcpip.ErrNotSupported + } } // GetSockOptInt implements tcpip.Endpoint.GetSockOptInt. diff --git a/pkg/tcpip/transport/raw/endpoint.go b/pkg/tcpip/transport/raw/endpoint.go index e37c00523..7b6a87ba9 100644 --- a/pkg/tcpip/transport/raw/endpoint.go +++ b/pkg/tcpip/transport/raw/endpoint.go @@ -601,7 +601,7 @@ func (e *endpoint) GetSockOpt(opt tcpip.GettableSocketOption) *tcpip.Error { // GetSockOptBool implements tcpip.Endpoint.GetSockOptBool. func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { switch opt { - case tcpip.KeepaliveEnabledOption: + case tcpip.KeepaliveEnabledOption, tcpip.AcceptConnOption: return false, nil case tcpip.IPHdrIncludedOption: @@ -646,7 +646,7 @@ func (e *endpoint) GetSockOptInt(opt tcpip.SockOptInt) (int, *tcpip.Error) { } // HandlePacket implements stack.RawTransportEndpoint.HandlePacket. -func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(pkt *stack.PacketBuffer) { e.rcvMu.Lock() // Drop the packet if our buffer is currently full or if this is an unassociated @@ -671,14 +671,16 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { return } + remoteAddr := pkt.Network().SourceAddress() + if e.bound { // If bound to a NIC, only accept data for that NIC. - if e.BindNICID != 0 && e.BindNICID != route.NICID() { + if e.BindNICID != 0 && e.BindNICID != pkt.NICID { e.rcvMu.Unlock() return } // If bound to an address, only accept data for that address. - if e.BindAddr != "" && e.BindAddr != route.RemoteAddress { + if e.BindAddr != "" && e.BindAddr != remoteAddr { e.rcvMu.Unlock() return } @@ -686,7 +688,7 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // If connected, only accept packets from the remote address we // connected to. - if e.connected && e.route.RemoteAddress != route.RemoteAddress { + if e.connected && e.route.RemoteAddress != remoteAddr { e.rcvMu.Unlock() return } @@ -696,8 +698,8 @@ func (e *endpoint) HandlePacket(route *stack.Route, pkt *stack.PacketBuffer) { // Push new packet into receive list and increment the buffer size. packet := &rawPacket{ senderAddr: tcpip.FullAddress{ - NIC: route.NICID(), - Addr: route.RemoteAddress, + NIC: pkt.NICID, + Addr: remoteAddr, }, } diff --git a/pkg/tcpip/transport/raw/endpoint_state.go b/pkg/tcpip/transport/raw/endpoint_state.go index 33bfb56cd..7d97cbdc7 100644 --- a/pkg/tcpip/transport/raw/endpoint_state.go +++ b/pkg/tcpip/transport/raw/endpoint_state.go @@ -37,57 +37,57 @@ func (p *rawPacket) loadData(data buffer.VectorisedView) { } // beforeSave is invoked by stateify. -func (ep *endpoint) beforeSave() { +func (e *endpoint) beforeSave() { // Stop incoming packets from being handled (and mutate endpoint state). // The lock will be released after saveRcvBufSizeMax(), which would have - // saved ep.rcvBufSizeMax and set it to 0 to continue blocking incoming + // saved e.rcvBufSizeMax and set it to 0 to continue blocking incoming // packets. - ep.rcvMu.Lock() + e.rcvMu.Lock() } // saveRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) saveRcvBufSizeMax() int { - max := ep.rcvBufSizeMax +func (e *endpoint) saveRcvBufSizeMax() int { + max := e.rcvBufSizeMax // Make sure no new packets will be handled regardless of the lock. - ep.rcvBufSizeMax = 0 + e.rcvBufSizeMax = 0 // Release the lock acquired in beforeSave() so regular endpoint closing // logic can proceed after save. - ep.rcvMu.Unlock() + e.rcvMu.Unlock() return max } // loadRcvBufSizeMax is invoked by stateify. -func (ep *endpoint) loadRcvBufSizeMax(max int) { - ep.rcvBufSizeMax = max +func (e *endpoint) loadRcvBufSizeMax(max int) { + e.rcvBufSizeMax = max } // afterLoad is invoked by stateify. -func (ep *endpoint) afterLoad() { - stack.StackFromEnv.RegisterRestoredEndpoint(ep) +func (e *endpoint) afterLoad() { + stack.StackFromEnv.RegisterRestoredEndpoint(e) } // Resume implements tcpip.ResumableEndpoint.Resume. -func (ep *endpoint) Resume(s *stack.Stack) { - ep.stack = s +func (e *endpoint) Resume(s *stack.Stack) { + e.stack = s // If the endpoint is connected, re-connect. - if ep.connected { + if e.connected { var err *tcpip.Error - ep.route, err = ep.stack.FindRoute(ep.RegisterNICID, ep.BindAddr, ep.route.RemoteAddress, ep.NetProto, false) + e.route, err = e.stack.FindRoute(e.RegisterNICID, e.BindAddr, e.route.RemoteAddress, e.NetProto, false) if err != nil { panic(err) } } // If the endpoint is bound, re-bind. - if ep.bound { - if ep.stack.CheckLocalAddress(ep.RegisterNICID, ep.NetProto, ep.BindAddr) == 0 { + if e.bound { + if e.stack.CheckLocalAddress(e.RegisterNICID, e.NetProto, e.BindAddr) == 0 { panic(tcpip.ErrBadLocalAddress) } } - if ep.associated { - if err := ep.stack.RegisterRawTransportEndpoint(ep.RegisterNICID, ep.NetProto, ep.TransProto, ep); err != nil { + if e.associated { + if err := e.stack.RegisterRawTransportEndpoint(e.RegisterNICID, e.NetProto, e.TransProto, e); err != nil { panic(err) } } diff --git a/pkg/tcpip/transport/tcp/accept.go b/pkg/tcpip/transport/tcp/accept.go index b706438bd..47982ca41 100644 --- a/pkg/tcpip/transport/tcp/accept.go +++ b/pkg/tcpip/transport/tcp/accept.go @@ -199,18 +199,25 @@ func (l *listenContext) isCookieValid(id stack.TransportEndpointID, cookie seqnu // createConnectingEndpoint creates a new endpoint in a connecting state, with // the connection parameters given by the arguments. -func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) *endpoint { +func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, irs seqnum.Value, rcvdSynOpts *header.TCPSynOptions, queue *waiter.Queue) (*endpoint, *tcpip.Error) { // Create a new endpoint. netProto := l.netProto if netProto == 0 { - netProto = s.route.NetProto + netProto = s.netProto } + + route, err := l.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(s.remoteLinkAddr) + n := newEndpoint(l.stack, netProto, queue) n.v6only = l.v6Only n.ID = s.id - n.boundNICID = s.route.NICID() - n.route = s.route.Clone() - n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.route.NetProto} + n.boundNICID = s.nicID + n.route = route + n.effectiveNetProtos = []tcpip.NetworkProtocolNumber{s.netProto} n.rcvBufSize = int(l.rcvWnd) n.amss = calculateAdvertisedMSS(n.userMSS, n.route) n.setEndpointState(StateConnecting) @@ -225,7 +232,7 @@ func (l *listenContext) createConnectingEndpoint(s *segment, iss seqnum.Value, i // window to grow to a really large value. n.rcvAutoParams.prevCopied = n.initialReceiveWindow() - return n + return n, nil } // createEndpointAndPerformHandshake creates a new endpoint in connected state @@ -236,7 +243,10 @@ func (l *listenContext) createEndpointAndPerformHandshake(s *segment, opts *head // Create new endpoint. irs := s.sequenceNumber isn := generateSecureISN(s.id, l.stack.Seed()) - ep := l.createConnectingEndpoint(s, isn, irs, opts, queue) + ep, err := l.createConnectingEndpoint(s, isn, irs, opts, queue) + if err != nil { + return nil, err + } // Lock the endpoint before registering to ensure that no out of // band changes are possible due to incoming packets etc till @@ -425,20 +435,17 @@ func (e *endpoint) notifyAborted() { // cookies to accept connections. func (e *endpoint) handleSynSegment(ctx *listenContext, s *segment, opts *header.TCPSynOptions) { defer ctx.synRcvdCount.dec() - defer func() { - e.mu.Lock() - e.decSynRcvdCount() - e.mu.Unlock() - }() defer s.decRef() n, err := ctx.createEndpointAndPerformHandshake(s, opts, &waiter.Queue{}, e.owner) if err != nil { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() + e.decSynRcvdCount() return } ctx.removePendingEndpoint(n) + e.decSynRcvdCount() n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() @@ -456,7 +463,9 @@ func (e *endpoint) incSynRcvdCount() bool { } func (e *endpoint) decSynRcvdCount() { + e.mu.Lock() e.synRcvdCount-- + e.mu.Unlock() } func (e *endpoint) acceptQueueIsFull() bool { @@ -468,7 +477,7 @@ func (e *endpoint) acceptQueueIsFull() bool { // handleListenSegment is called when a listening endpoint receives a segment // and needs to handle it. -func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { +func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) *tcpip.Error { e.rcvListMu.Lock() rcvClosed := e.rcvClosed e.rcvListMu.Unlock() @@ -478,8 +487,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // RFC 793 section 3.4 page 35 (figure 12) outlines that a RST // must be sent in response to a SYN-ACK while in the listen // state to prevent completing a handshake from an old SYN. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } switch { @@ -493,13 +501,13 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !e.acceptQueueIsFull() && e.incSynRcvdCount() { s.incRef() go e.handleSynSegment(ctx, s, &opts) // S/R-SAFE: synRcvdCount is the barrier. - return + return nil } ctx.synRcvdCount.dec() e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } else { // If cookies are in use but the endpoint accept queue // is full then drop the syn. @@ -507,10 +515,17 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowSynDrop.Increment() e.stats.ReceiveErrors.ListenOverflowSynDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } cookie := ctx.createCookie(s.id, s.sequenceNumber, encodeMSS(opts.MSS)) + route, err := e.stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Send SYN without window scaling because we currently // don't encode this information in the cookie. // @@ -524,9 +539,9 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { TS: opts.TS, TSVal: tcpTimeStamp(time.Now(), timeStampOffset()), TSEcr: opts.TSVal, - MSS: calculateAdvertisedMSS(e.userMSS, s.route), + MSS: calculateAdvertisedMSS(e.userMSS, route), } - e.sendSynTCP(&s.route, tcpFields{ + fields := tcpFields{ id: s.id, ttl: e.ttl, tos: e.sendTOS, @@ -534,8 +549,12 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { seq: cookie, ack: s.sequenceNumber + 1, rcvWnd: ctx.rcvWnd, - }, synOpts) + } + if err := e.sendSynTCP(&route, fields, synOpts); err != nil { + return err + } e.stack.Stats().TCP.ListenOverflowSynCookieSent.Increment() + return nil } case (s.flags & header.TCPFlagAck) != 0: @@ -548,7 +567,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.ListenOverflowAckDrop.Increment() e.stats.ReceiveErrors.ListenOverflowAckDrop.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } if !ctx.synRcvdCount.synCookiesInUse() { @@ -567,8 +586,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { // The only time we should reach here when a connection // was opened and closed really quickly and a delayed // ACK was received from the sender. - replyWithReset(s, e.sendTOS, e.ttl) - return + return replyWithReset(e.stack, s, e.sendTOS, e.ttl) } iss := s.ackNumber - 1 @@ -588,7 +606,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { if !ok || int(data) >= len(mssTable) { e.stack.Stats().TCP.ListenOverflowInvalidSynCookieRcvd.Increment() e.stack.Stats().DroppedPackets.Increment() - return + return nil } e.stack.Stats().TCP.ListenOverflowSynCookieRcvd.Increment() // Create newly accepted endpoint and deliver it. @@ -609,7 +627,10 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { rcvdSynOptions.TSEcr = s.parsedOptions.TSEcr } - n := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + n, err := ctx.createConnectingEndpoint(s, iss, irs, rcvdSynOptions, &waiter.Queue{}) + if err != nil { + return err + } n.mu.Lock() @@ -623,7 +644,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return nil } // Register new endpoint so that packets are routed to it. @@ -633,7 +654,7 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { e.stack.Stats().TCP.FailedConnectionAttempts.Increment() e.stats.FailedConnectionAttempts.Increment() - return + return err } n.isRegistered = true @@ -671,12 +692,16 @@ func (e *endpoint) handleListenSegment(ctx *listenContext, s *segment) { n.startAcceptedLoop() e.stack.Stats().TCP.PassiveConnectionOpenings.Increment() go e.deliverAccepted(n) + return nil + + default: + return nil } } // protocolListenLoop is the main loop of a listening TCP endpoint. It runs in // its own goroutine and is responsible for handling connection requests. -func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { +func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) { e.mu.Lock() v6Only := e.v6only ctx := newListenContext(e.stack, e, rcvWnd, v6Only, e.NetProto) @@ -715,12 +740,14 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { case wakerForNotification: n := e.fetchNotifications() if n¬ifyClose != 0 { - return nil + return } if n¬ifyDrain != 0 { for !e.segmentQueue.empty() { s := e.segmentQueue.dequeue() - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } close(e.drainDone) @@ -739,7 +766,9 @@ func (e *endpoint) protocolListenLoop(rcvWnd seqnum.Size) *tcpip.Error { break } - e.handleListenSegment(ctx, s) + // TODO(gvisor.dev/issue/4690): Better handle errors instead of + // silently dropping. + _ = e.handleListenSegment(ctx, s) s.decRef() } diff --git a/pkg/tcpip/transport/tcp/connect.go b/pkg/tcpip/transport/tcp/connect.go index 0aaef495d..2facbebec 100644 --- a/pkg/tcpip/transport/tcp/connect.go +++ b/pkg/tcpip/transport/tcp/connect.go @@ -293,9 +293,9 @@ func (h *handshake) synSentState(s *segment) *tcpip.Error { MSS: amss, } if ttl == 0 { - ttl = s.route.DefaultTTL() + ttl = h.ep.route.DefaultTTL() } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: ttl, tos: h.ep.sendTOS, @@ -356,7 +356,7 @@ func (h *handshake) synRcvdState(s *segment) *tcpip.Error { SACKPermitted: h.ep.sackPermitted, MSS: h.ep.amss, } - h.ep.sendSynTCP(&s.route, tcpFields{ + h.ep.sendSynTCP(&h.ep.route, tcpFields{ id: h.ep.ID, ttl: h.ep.ttl, tos: h.ep.sendTOS, @@ -496,7 +496,9 @@ func (h *handshake) resolveRoute() *tcpip.Error { } // Wait for notification. - index, _ = s.Fetch(true) + h.ep.mu.Unlock() + index, _ = s.Fetch(true /* block */) + h.ep.mu.Lock() } } @@ -566,8 +568,10 @@ func (h *handshake) execute() *tcpip.Error { }, synOpts) for h.state != handshakeCompleted { + // Unlock before blocking, and reacquire again afterwards (h.ep.mu is held + // throughout handshake processing). h.ep.mu.Unlock() - index, _ := s.Fetch(true) + index, _ := s.Fetch(true /* block */) h.ep.mu.Lock() switch index { @@ -767,7 +771,7 @@ func buildTCPHdr(r *stack.Route, tf tcpFields, pkt *stack.PacketBuffer, gso *sta // TCP header, then the kernel calculate a checksum of the // header and data and get the right sum of the TCP packet. tcp.SetChecksum(xsum) - } else if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 { + } else if r.RequiresTXTransportChecksum() { xsum = header.ChecksumVV(pkt.Data, xsum) tcp.SetChecksum(^tcp.CalculateChecksum(xsum)) } @@ -1040,13 +1044,13 @@ func (e *endpoint) transitionToStateCloseLocked() { // only when the endpoint is in StateClose and we want to deliver the segment // to any other listening endpoint. We reply with RST if we cannot find one. func (e *endpoint) tryDeliverSegmentFromClosedEndpoint(s *segment) { - ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, &s.route) + ep := e.stack.FindTransportEndpoint(e.NetProto, e.TransProto, e.ID, s.nicID) if ep == nil && e.NetProto == header.IPv6ProtocolNumber && e.EndpointInfo.TransportEndpointInfo.ID.LocalAddress.To4() != "" { // Dual-stack socket, try IPv4. - ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, &s.route) + ep = e.stack.FindTransportEndpoint(header.IPv4ProtocolNumber, e.TransProto, e.ID, s.nicID) } if ep == nil { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(e.stack, s, stack.DefaultTOS, 0 /* ttl */) s.decRef() return } @@ -1366,7 +1370,9 @@ func (e *endpoint) protocolMainLoop(handshake bool, wakerInitDone chan<- struct{ drained := e.drainDone != nil if drained { close(e.drainDone) + e.mu.Unlock() <-e.undrain + e.mu.Lock() } // Set up the functions that will be called when the main protocol loop @@ -1535,7 +1541,7 @@ loop: } e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() // We need to double check here because the notification may be @@ -1620,7 +1626,7 @@ func (e *endpoint) handleTimeWaitSegments() (extendTimeWait bool, reuseTW func() netProtos = []tcpip.NetworkProtocolNumber{header.IPv4ProtocolNumber, header.IPv6ProtocolNumber} } for _, netProto := range netProtos { - if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, &s.route); listenEP != nil { + if listenEP := e.stack.FindTransportEndpoint(netProto, info.TransProto, newID, s.nicID); listenEP != nil { tcpEP := listenEP.(*endpoint) if EndpointState(tcpEP.State()) == StateListen { reuseTW = func() { @@ -1683,7 +1689,7 @@ func (e *endpoint) doTimeWait() (twReuse func()) { for { e.mu.Unlock() - v, _ := s.Fetch(true) + v, _ := s.Fetch(true /* block */) e.mu.Lock() switch v { case newSegment: diff --git a/pkg/tcpip/transport/tcp/dispatcher.go b/pkg/tcpip/transport/tcp/dispatcher.go index 98aecab9e..21162f01a 100644 --- a/pkg/tcpip/transport/tcp/dispatcher.go +++ b/pkg/tcpip/transport/tcp/dispatcher.go @@ -172,10 +172,11 @@ func (d *dispatcher) wait() { d.wg.Wait() } -func (d *dispatcher) queuePacket(r *stack.Route, stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (d *dispatcher) queuePacket(stackEP stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { ep := stackEP.(*endpoint) - s := newSegment(r, id, pkt) - if !s.parse() { + + s := newIncomingSegment(id, pkt) + if !s.parse(pkt.RXTransportChecksumValidated) { ep.stack.Stats().MalformedRcvdPackets.Increment() ep.stack.Stats().TCP.InvalidSegmentsReceived.Increment() ep.stats.ReceiveErrors.MalformedPacketsReceived.Increment() diff --git a/pkg/tcpip/transport/tcp/dual_stack_test.go b/pkg/tcpip/transport/tcp/dual_stack_test.go index 560b4904c..a6f25896b 100644 --- a/pkg/tcpip/transport/tcp/dual_stack_test.go +++ b/pkg/tcpip/transport/tcp/dual_stack_test.go @@ -236,6 +236,25 @@ func TestV6ConnectWhenBoundToWildcard(t *testing.T) { testV6Connect(t, c) } +func TestStackV6OnlyConnectWhenBoundToWildcard(t *testing.T) { + c := context.NewWithOpts(t, context.Options{ + EnableV6: true, + MTU: defaultMTU, + }) + defer c.Cleanup() + + // Create a v6 endpoint but don't set the v6-only TCP option. + c.CreateV6Endpoint(false) + + // Bind to wildcard. + if err := c.EP.Bind(tcpip.FullAddress{}); err != nil { + t.Fatalf("Bind failed: %v", err) + } + + // Test the connection request. + testV6Connect(t, c) +} + func TestV6ConnectWhenBoundToLocalAddress(t *testing.T) { c := context.New(t, defaultMTU) defer c.Cleanup() diff --git a/pkg/tcpip/transport/tcp/endpoint.go b/pkg/tcpip/transport/tcp/endpoint.go index 3bcd3923a..258f9f1bb 100644 --- a/pkg/tcpip/transport/tcp/endpoint.go +++ b/pkg/tcpip/transport/tcp/endpoint.go @@ -721,9 +721,9 @@ func (e *endpoint) LockUser() { for { // Try first if the sock is locked then check if it's owned // by another user goroutine if not then we spin, otherwise - // we just goto sleep on the Lock() and wait. + // we just go to sleep on the Lock() and wait. if !e.mu.TryLock() { - // If socket is owned by the user then just goto sleep + // If socket is owned by the user then just go to sleep // as the lock could be held for a reasonably long time. if atomic.LoadUint32(&e.ownedByUser) == 1 { e.mu.Lock() @@ -1425,7 +1425,7 @@ func (e *endpoint) Write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c queueAndSend := func() (int64, <-chan struct{}, *tcpip.Error) { // Add data to the send queue. - s := newSegmentFromView(&e.route, e.ID, v) + s := newOutgoingSegment(e.ID, v) e.sndBufUsed += len(v) e.sndBufInQueue += seqnum.Size(len(v)) e.sndQueue.PushBack(s) @@ -1999,6 +1999,12 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { case tcpip.MulticastLoopOption: return true, nil + case tcpip.AcceptConnOption: + e.LockUser() + defer e.UnlockUser() + + return e.EndpointState() == StateListen, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -2310,7 +2316,7 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc // done yet) or the reservation was freed between the check above and // the FindTransportEndpoint below. But rather than retry the same port // we just skip it and move on. - transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, &r) + transEP := e.stack.FindTransportEndpoint(netProto, ProtocolNumber, transEPID, r.NICID()) if transEP == nil { // ReservePort failed but there is no registered endpoint with // demuxer. Which indicates there is at least some endpoint that has @@ -2379,7 +2385,6 @@ func (e *endpoint) connect(addr tcpip.FullAddress, handshake bool, run bool) *tc for _, l := range []segmentList{e.segmentQueue.list, e.sndQueue, e.snd.writeList} { for s := l.Front(); s != nil; s = s.Next() { s.id = e.ID - s.route = r.Clone() e.sndWaker.Assert() } } @@ -2445,7 +2450,7 @@ func (e *endpoint) shutdownLocked(flags tcpip.ShutdownFlags) *tcpip.Error { } // Queue fin segment. - s := newSegmentFromView(&e.route, e.ID, nil) + s := newOutgoingSegment(e.ID, nil) e.sndQueue.PushBack(s) e.sndBufInQueue++ // Mark endpoint as closed. @@ -2627,14 +2632,16 @@ func (e *endpoint) bindLocked(addr tcpip.FullAddress) (err *tcpip.Error) { return err } - // Expand netProtos to include v4 and v6 if the caller is binding to a - // wildcard (empty) address, and this is an IPv6 endpoint with v6only - // set to false. netProtos := []tcpip.NetworkProtocolNumber{netProto} - if netProto == header.IPv6ProtocolNumber && !e.v6only && addr.Addr == "" { - netProtos = []tcpip.NetworkProtocolNumber{ - header.IPv6ProtocolNumber, - header.IPv4ProtocolNumber, + + // Expand netProtos to include v4 and v6 under dual-stack if the caller is + // binding to a wildcard (empty) address, and this is an IPv6 endpoint with + // v6only set to false. + if netProto == header.IPv6ProtocolNumber { + stackHasV4 := e.stack.CheckNetworkProtocol(header.IPv4ProtocolNumber) + alsoBindToV4 := !e.v6only && addr.Addr == "" && stackHasV4 + if alsoBindToV4 { + netProtos = append(netProtos, header.IPv4ProtocolNumber) } } @@ -2715,7 +2722,7 @@ func (e *endpoint) getRemoteAddress() tcpip.FullAddress { } } -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (*endpoint) HandlePacket(stack.TransportEndpointID, *stack.PacketBuffer) { // TCP HandlePacket is not required anymore as inbound packets first // land at the Dispatcher which then can either delivery using the // worker go routine or directly do the invoke the tcp processing inline @@ -3074,9 +3081,9 @@ func (e *endpoint) initHardwareGSO() { } func (e *endpoint) initGSO() { - if e.route.Capabilities()&stack.CapabilityHardwareGSO != 0 { + if e.route.HasHardwareGSOCapability() { e.initHardwareGSO() - } else if e.route.Capabilities()&stack.CapabilitySoftwareGSO != 0 { + } else if e.route.HasSoftwareGSOCapability() { e.gso = &stack.GSO{ MaxSize: e.route.GSOMaxSize(), Type: stack.GSOSW, diff --git a/pkg/tcpip/transport/tcp/endpoint_state.go b/pkg/tcpip/transport/tcp/endpoint_state.go index b25431467..2bcc5e1c2 100644 --- a/pkg/tcpip/transport/tcp/endpoint_state.go +++ b/pkg/tcpip/transport/tcp/endpoint_state.go @@ -53,8 +53,8 @@ func (e *endpoint) beforeSave() { switch { case epState == StateInitial || epState == StateBound: case epState.connected() || epState.handshake(): - if e.route.Capabilities()&stack.CapabilitySaveRestore == 0 { - if e.route.Capabilities()&stack.CapabilityDisconnectOk == 0 { + if !e.route.HasSaveRestoreCapability() { + if !e.route.HasDisconncetOkCapability() { panic(tcpip.ErrSaveRejection{fmt.Errorf("endpoint cannot be saved in connected state: local %v:%d, remote %v:%d", e.ID.LocalAddress, e.ID.LocalPort, e.ID.RemoteAddress, e.ID.RemotePort)}) } e.resetConnectionLocked(tcpip.ErrConnectionAborted) diff --git a/pkg/tcpip/transport/tcp/forwarder.go b/pkg/tcpip/transport/tcp/forwarder.go index 070b634b4..0664789da 100644 --- a/pkg/tcpip/transport/tcp/forwarder.go +++ b/pkg/tcpip/transport/tcp/forwarder.go @@ -30,6 +30,8 @@ import ( // The canonical way of using it is to pass the Forwarder.HandlePacket function // to stack.SetTransportProtocolHandler. type Forwarder struct { + stack *stack.Stack + maxInFlight int handler func(*ForwarderRequest) @@ -48,6 +50,7 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward rcvWnd = DefaultReceiveBufferSize } return &Forwarder{ + stack: s, maxInFlight: maxInFlight, handler: handler, inFlight: make(map[stack.TransportEndpointID]struct{}), @@ -61,12 +64,12 @@ func NewForwarder(s *stack.Stack, rcvWnd, maxInFlight int, handler func(*Forward // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { - s := newSegment(r, id, pkt) +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { + s := newIncomingSegment(id, pkt) defer s.decRef() // We only care about well-formed SYN packets. - if !s.parse() || !s.csumValid || s.flags != header.TCPFlagSyn { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid || s.flags != header.TCPFlagSyn { return false } @@ -128,9 +131,8 @@ func (r *ForwarderRequest) Complete(sendReset bool) { delete(r.forwarder.inFlight, r.segment.id) r.forwarder.mu.Unlock() - // If the caller requested, send a reset. if sendReset { - replyWithReset(r.segment, stack.DefaultTOS, r.segment.route.DefaultTTL()) + replyWithReset(r.forwarder.stack, r.segment, stack.DefaultTOS, 0 /* ttl */) } // Release all resources. diff --git a/pkg/tcpip/transport/tcp/protocol.go b/pkg/tcpip/transport/tcp/protocol.go index 5bce73605..2329aca4b 100644 --- a/pkg/tcpip/transport/tcp/protocol.go +++ b/pkg/tcpip/transport/tcp/protocol.go @@ -187,8 +187,8 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // to a specific processing queue. Each queue is serviced by its own processor // goroutine which is responsible for dequeuing and doing full TCP dispatch of // the packet. -func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { - p.dispatcher.queuePacket(r, ep, id, pkt) +func (p *protocol) QueuePacket(ep stack.TransportEndpoint, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { + p.dispatcher.queuePacket(ep, id, pkt) } // HandleUnknownDestinationPacket handles packets targeted at this protocol but @@ -198,24 +198,32 @@ func (p *protocol) QueuePacket(r *stack.Route, ep stack.TransportEndpoint, id st // a reset is sent in response to any incoming segment except another reset. In // particular, SYNs addressed to a non-existent connection are rejected by this // means." - -func (*protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { - s := newSegment(r, id, pkt) +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { + s := newIncomingSegment(id, pkt) defer s.decRef() - if !s.parse() || !s.csumValid { + if !s.parse(pkt.RXTransportChecksumValidated) || !s.csumValid { return stack.UnknownDestinationPacketMalformed } if !s.flagIsSet(header.TCPFlagRst) { - replyWithReset(s, stack.DefaultTOS, s.route.DefaultTTL()) + replyWithReset(p.stack, s, stack.DefaultTOS, 0) } return stack.UnknownDestinationPacketHandled } // replyWithReset replies to the given segment with a reset segment. -func replyWithReset(s *segment, tos, ttl uint8) { +// +// If the passed TTL is 0, then the route's default TTL will be used. +func replyWithReset(stack *stack.Stack, s *segment, tos, ttl uint8) *tcpip.Error { + route, err := stack.FindRoute(s.nicID, s.dstAddr, s.srcAddr, s.netProto, false /* multicastLoop */) + if err != nil { + return err + } + defer route.Release() + route.ResolveWith(s.remoteLinkAddr) + // Get the seqnum from the packet if the ack flag is set. seq := seqnum.Value(0) ack := seqnum.Value(0) @@ -237,7 +245,12 @@ func replyWithReset(s *segment, tos, ttl uint8) { flags |= header.TCPFlagAck ack = s.sequenceNumber.Add(s.logicalLen()) } - sendTCP(&s.route, tcpFields{ + + if ttl == 0 { + ttl = route.DefaultTTL() + } + + return sendTCP(&route, tcpFields{ id: s.id, ttl: ttl, tos: tos, diff --git a/pkg/tcpip/transport/tcp/sack_scoreboard.go b/pkg/tcpip/transport/tcp/sack_scoreboard.go index 7ef2df377..833a7b470 100644 --- a/pkg/tcpip/transport/tcp/sack_scoreboard.go +++ b/pkg/tcpip/transport/tcp/sack_scoreboard.go @@ -164,7 +164,7 @@ func (s *SACKScoreboard) IsSACKED(r header.SACKBlock) bool { return found } -// Dump prints the state of the scoreboard structure. +// String returns human-readable state of the scoreboard structure. func (s *SACKScoreboard) String() string { var str strings.Builder str.WriteString("SACKScoreboard: {") diff --git a/pkg/tcpip/transport/tcp/segment.go b/pkg/tcpip/transport/tcp/segment.go index 1f9c5cf50..2091989cc 100644 --- a/pkg/tcpip/transport/tcp/segment.go +++ b/pkg/tcpip/transport/tcp/segment.go @@ -19,6 +19,7 @@ import ( "sync/atomic" "time" + "gvisor.dev/gvisor/pkg/tcpip" "gvisor.dev/gvisor/pkg/tcpip/buffer" "gvisor.dev/gvisor/pkg/tcpip/header" "gvisor.dev/gvisor/pkg/tcpip/seqnum" @@ -45,9 +46,18 @@ type segment struct { ep *endpoint qFlags queueFlags id stack.TransportEndpointID `state:"manual"` - route stack.Route `state:"manual"` - data buffer.VectorisedView `state:".(buffer.VectorisedView)"` - hdr header.TCP + + // TODO(gvisor.dev/issue/4417): Hold a stack.PacketBuffer instead of + // individual members for link/network packet info. + srcAddr tcpip.Address + dstAddr tcpip.Address + netProto tcpip.NetworkProtocolNumber + nicID tcpip.NICID + remoteLinkAddr tcpip.LinkAddress + + data buffer.VectorisedView `state:".(buffer.VectorisedView)"` + + hdr header.TCP // views is used as buffer for data when its length is large // enough to store a VectorisedView. views [8]buffer.View `state:"nosave"` @@ -76,11 +86,16 @@ type segment struct { acked bool } -func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { +func newIncomingSegment(id stack.TransportEndpointID, pkt *stack.PacketBuffer) *segment { + netHdr := pkt.Network() s := &segment{ - refCnt: 1, - id: id, - route: r.Clone(), + refCnt: 1, + id: id, + srcAddr: netHdr.SourceAddress(), + dstAddr: netHdr.DestinationAddress(), + netProto: pkt.NetworkProtocolNumber, + nicID: pkt.NICID, + remoteLinkAddr: pkt.SourceLinkAddress(), } s.data = pkt.Data.Clone(s.views[:]) s.hdr = header.TCP(pkt.TransportHeader().View()) @@ -88,11 +103,10 @@ func newSegment(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketB return s } -func newSegmentFromView(r *stack.Route, id stack.TransportEndpointID, v buffer.View) *segment { +func newOutgoingSegment(id stack.TransportEndpointID, v buffer.View) *segment { s := &segment{ refCnt: 1, id: id, - route: r.Clone(), } s.rcvdTime = time.Now() if len(v) != 0 { @@ -110,7 +124,9 @@ func (s *segment) clone() *segment { ackNumber: s.ackNumber, flags: s.flags, window: s.window, - route: s.route.Clone(), + netProto: s.netProto, + nicID: s.nicID, + remoteLinkAddr: s.remoteLinkAddr, viewToDeliver: s.viewToDeliver, rcvdTime: s.rcvdTime, xmitTime: s.xmitTime, @@ -160,7 +176,6 @@ func (s *segment) decRef() { panic(fmt.Sprintf("unexpected queue flag %b set for segment", s.qFlags)) } } - s.route.Release() } } @@ -198,10 +213,10 @@ func (s *segment) segMemSize() int { // // Returns boolean indicating if the parsing was successful. // -// If checksum verification is not offloaded then parse also verifies the +// If checksum verification may not be skipped, parse also verifies the // TCP checksum and stores the checksum and result of checksum verification in // the csum and csumValid fields of the segment. -func (s *segment) parse() bool { +func (s *segment) parse(skipChecksumValidation bool) bool { // h is the header followed by the payload. We check that the offset to // the data respects the following constraints: // 1. That it's at least the minimum header size; if we don't do this @@ -220,16 +235,14 @@ func (s *segment) parse() bool { s.options = []byte(s.hdr[header.TCPMinimumSize:]) s.parsedOptions = header.ParseTCPOptions(s.options) - // Query the link capabilities to decide if checksum validation is - // required. verifyChecksum := true - if s.route.Capabilities()&stack.CapabilityRXChecksumOffload != 0 { + if skipChecksumValidation { s.csumValid = true verifyChecksum = false } if verifyChecksum { s.csum = s.hdr.Checksum() - xsum := s.route.PseudoHeaderChecksum(ProtocolNumber, uint16(s.data.Size()+len(s.hdr))) + xsum := header.PseudoHeaderChecksum(ProtocolNumber, s.srcAddr, s.dstAddr, uint16(s.data.Size()+len(s.hdr))) xsum = s.hdr.CalculateChecksum(xsum) xsum = header.ChecksumVV(s.data, xsum) s.csumValid = xsum == 0xffff diff --git a/pkg/tcpip/transport/tcp/snd.go b/pkg/tcpip/transport/tcp/snd.go index 6fa8d63cd..ab5fa4fb7 100644 --- a/pkg/tcpip/transport/tcp/snd.go +++ b/pkg/tcpip/transport/tcp/snd.go @@ -1285,6 +1285,10 @@ func (s *sender) checkDuplicateAck(seg *segment) (rtx bool) { // See: https://tools.ietf.org/html/draft-ietf-tcpm-rack-08#section-7.2 // steps 2 and 3. func (s *sender) walkSACK(rcvdSeg *segment) { + if len(rcvdSeg.parsedOptions.SACKBlocks) == 0 { + return + } + // Sort the SACK blocks. The first block is the most recent unacked // block. The following blocks can be in arbitrary order. sackBlocks := make([]header.SACKBlock, len(rcvdSeg.parsedOptions.SACKBlocks)) diff --git a/pkg/tcpip/transport/tcp/tcp_test.go b/pkg/tcpip/transport/tcp/tcp_test.go index a7149efd0..5f05608e2 100644 --- a/pkg/tcpip/transport/tcp/tcp_test.go +++ b/pkg/tcpip/transport/tcp/tcp_test.go @@ -5131,6 +5131,7 @@ func TestKeepalive(t *testing.T) { } func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendPacket(nil, &context.Headers{ @@ -5175,6 +5176,7 @@ func executeHandshake(t *testing.T, c *context.Context, srcPort uint16, synCooki } func executeV6Handshake(t *testing.T, c *context.Context, srcPort uint16, synCookieInUse bool) (irs, iss seqnum.Value) { + t.Helper() // Send a SYN request. irs = seqnum.Value(789) c.SendV6Packet(nil, &context.Headers{ @@ -5238,13 +5240,14 @@ func TestListenBacklogFull(t *testing.T) { // Test acceptance. // Start listening. - listenBacklog := 2 + listenBacklog := 10 if err := c.EP.Listen(listenBacklog); err != nil { t.Fatalf("Listen failed: %s", err) } - for i := 0; i < listenBacklog; i++ { - executeHandshake(t, c, context.TestPort+uint16(i), false /*synCookieInUse */) + lastPortOffset := uint16(0) + for ; int(lastPortOffset) < listenBacklog; lastPortOffset++ { + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) } time.Sleep(50 * time.Millisecond) @@ -5252,7 +5255,7 @@ func TestListenBacklogFull(t *testing.T) { // Now execute send one more SYN. The stack should not respond as the backlog // is full at this point. c.SendPacket(nil, &context.Headers{ - SrcPort: context.TestPort + 2, + SrcPort: context.TestPort + uint16(lastPortOffset), DstPort: context.StackPort, Flags: header.TCPFlagSyn, SeqNum: seqnum.Value(789), @@ -5293,7 +5296,7 @@ func TestListenBacklogFull(t *testing.T) { } // Now a new handshake must succeed. - executeHandshake(t, c, context.TestPort+2, false /*synCookieInUse */) + executeHandshake(t, c, context.TestPort+lastPortOffset, false /*synCookieInUse */) newEP, _, err := c.EP.Accept(nil) if err == tcpip.ErrWouldBlock { @@ -6722,6 +6725,13 @@ func TestTCPTimeWaitNewSyn(t *testing.T) { c.CheckNoPacketTimeout("unexpected packet received in response to SYN", 1*time.Second) + // drain any older notifications from the notification channel before attempting + // 2nd connection. + select { + case <-ch: + default: + } + // Send a SYN request w/ sequence number higher than // the highest sequence number sent. iss = seqnum.Value(792) diff --git a/pkg/tcpip/transport/tcp/testing/context/context.go b/pkg/tcpip/transport/tcp/testing/context/context.go index 4d7847142..f791f8f13 100644 --- a/pkg/tcpip/transport/tcp/testing/context/context.go +++ b/pkg/tcpip/transport/tcp/testing/context/context.go @@ -112,6 +112,18 @@ type Headers struct { TCPOpts []byte } +// Options contains options for creating a new test context. +type Options struct { + // EnableV4 indicates whether IPv4 should be enabled. + EnableV4 bool + + // EnableV6 indicates whether IPv4 should be enabled. + EnableV6 bool + + // MTU indicates the maximum transmission unit on the link layer. + MTU uint32 +} + // Context provides an initialized Network stack and a link layer endpoint // for use in TCP tests. type Context struct { @@ -154,10 +166,30 @@ type Context struct { // New allocates and initializes a test context containing a new // stack and a link-layer endpoint. func New(t *testing.T, mtu uint32) *Context { - s := stack.New(stack.Options{ - NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol, ipv6.NewProtocol}, - TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + return NewWithOpts(t, Options{ + EnableV4: true, + EnableV6: true, + MTU: mtu, }) +} + +// NewWithOpts allocates and initializes a test context containing a new +// stack and a link-layer endpoint with specific options. +func NewWithOpts(t *testing.T, opts Options) *Context { + if opts.MTU == 0 { + panic("MTU must be greater than 0") + } + + stackOpts := stack.Options{ + TransportProtocols: []stack.TransportProtocolFactory{tcp.NewProtocol}, + } + if opts.EnableV4 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv4.NewProtocol) + } + if opts.EnableV6 { + stackOpts.NetworkProtocols = append(stackOpts.NetworkProtocols, ipv6.NewProtocol) + } + s := stack.New(stackOpts) const sendBufferSize = 1 << 20 // 1 MiB const recvBufferSize = 1 << 20 // 1 MiB @@ -182,50 +214,55 @@ func New(t *testing.T, mtu uint32) *Context { // Some of the congestion control tests send up to 640 packets, we so // set the channel size to 1000. - ep := channel.New(1000, mtu, "") + ep := channel.New(1000, opts.MTU, "") wep := stack.LinkEndpoint(ep) if testing.Verbose() { wep = sniffer.New(ep) } - opts := stack.NICOptions{Name: "nic1"} - if err := s.CreateNICWithOptions(1, wep, opts); err != nil { + nicOpts := stack.NICOptions{Name: "nic1"} + if err := s.CreateNICWithOptions(1, wep, nicOpts); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts, err) } - wep2 := stack.LinkEndpoint(channel.New(1000, mtu, "")) + wep2 := stack.LinkEndpoint(channel.New(1000, opts.MTU, "")) if testing.Verbose() { - wep2 = sniffer.New(channel.New(1000, mtu, "")) + wep2 = sniffer.New(channel.New(1000, opts.MTU, "")) } opts2 := stack.NICOptions{Name: "nic2"} if err := s.CreateNICWithOptions(2, wep2, opts2); err != nil { t.Fatalf("CreateNICWithOptions(_, _, %+v) failed: %v", opts2, err) } - v4ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv4.ProtocolNumber, - AddressWithPrefix: StackAddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) - } - - v6ProtocolAddr := tcpip.ProtocolAddress{ - Protocol: ipv6.ProtocolNumber, - AddressWithPrefix: StackV6AddrWithPrefix, - } - if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { - t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) - } + var routeTable []tcpip.Route - s.SetRouteTable([]tcpip.Route{ - { + if opts.EnableV4 { + v4ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv4.ProtocolNumber, + AddressWithPrefix: StackAddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v4ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v4ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv4EmptySubnet, NIC: 1, - }, - { + }) + } + + if opts.EnableV6 { + v6ProtocolAddr := tcpip.ProtocolAddress{ + Protocol: ipv6.ProtocolNumber, + AddressWithPrefix: StackV6AddrWithPrefix, + } + if err := s.AddProtocolAddress(1, v6ProtocolAddr); err != nil { + t.Fatalf("AddProtocolAddress(1, %#v): %s", v6ProtocolAddr, err) + } + routeTable = append(routeTable, tcpip.Route{ Destination: header.IPv6EmptySubnet, NIC: 1, - }, - }) + }) + } + + s.SetRouteTable(routeTable) return &Context{ t: t, @@ -373,6 +410,9 @@ func (c *Context) SendICMPPacket(typ header.ICMPv4Type, code header.ICMPv4Code, const icmpv4VariableHeaderOffset = 4 copy(icmp[icmpv4VariableHeaderOffset:], p1) copy(icmp[header.ICMPv4PayloadOffset:], p2) + icmp.SetChecksum(0) + checksum := ^header.Checksum(icmp, 0 /* initial */) + icmp.SetChecksum(checksum) // Inject packet. pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{ diff --git a/pkg/tcpip/transport/udp/endpoint.go b/pkg/tcpip/transport/udp/endpoint.go index d57ed5d79..9bcb918bb 100644 --- a/pkg/tcpip/transport/udp/endpoint.go +++ b/pkg/tcpip/transport/udp/endpoint.go @@ -487,6 +487,11 @@ func (e *endpoint) write(p tcpip.Payloader, opts tcpip.WriteOptions) (int64, <-c nicID = e.BindNICID } + if to.Port == 0 { + // Port 0 is an invalid port to send to. + return 0, nil, tcpip.ErrInvalidEndpointState + } + dst, netProto, err := e.checkV4MappedLocked(*to) if err != nil { return 0, nil, err @@ -895,6 +900,9 @@ func (e *endpoint) GetSockOptBool(opt tcpip.SockOptBool) (bool, *tcpip.Error) { return v, nil + case tcpip.AcceptConnOption: + return false, nil + default: return false, tcpip.ErrUnknownProtocolOption } @@ -1009,7 +1017,7 @@ func sendUDP(r *stack.Route, data buffer.VectorisedView, localPort, remotePort u // On IPv4, UDP checksum is optional, and a zero value indicates the // transmitter skipped the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). - if r.Capabilities()&stack.CapabilityTXChecksumOffload == 0 && + if r.RequiresTXTransportChecksum() && (!noChecksum || r.NetProto == header.IPv6ProtocolNumber) { xsum := r.PseudoHeaderChecksum(ProtocolNumber, length) for _, v := range data.Views() { @@ -1366,6 +1374,12 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { e.rcvMu.Unlock() } + e.lastErrorMu.Lock() + hasError := e.lastError != nil + e.lastErrorMu.Unlock() + if hasError { + result |= waiter.EventErr + } return result } @@ -1373,10 +1387,11 @@ func (e *endpoint) Readiness(mask waiter.EventMask) waiter.EventMask { // On IPv4, UDP checksum is optional, and a zero value means the transmitter // omitted the checksum generation (RFC768). // On IPv6, UDP checksum is not optional (RFC2460 Section 8.1). -func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) bool { - if r.Capabilities()&stack.CapabilityRXChecksumOffload == 0 && - (hdr.Checksum() != 0 || r.NetProto == header.IPv6ProtocolNumber) { - xsum := r.PseudoHeaderChecksum(ProtocolNumber, hdr.Length()) +func verifyChecksum(hdr header.UDP, pkt *stack.PacketBuffer) bool { + if !pkt.RXTransportChecksumValidated && + (hdr.Checksum() != 0 || pkt.NetworkProtocolNumber == header.IPv6ProtocolNumber) { + netHdr := pkt.Network() + xsum := header.PseudoHeaderChecksum(ProtocolNumber, netHdr.DestinationAddress(), netHdr.SourceAddress(), hdr.Length()) for _, v := range pkt.Data.Views() { xsum = header.Checksum(v, xsum) } @@ -1387,7 +1402,7 @@ func verifyChecksum(r *stack.Route, hdr header.UDP, pkt *stack.PacketBuffer) boo // HandlePacket is called by the stack when new packets arrive to this transport // endpoint. -func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) { +func (e *endpoint) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) { // Get the header then trim it from the view. hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { @@ -1397,7 +1412,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk return } - if !verifyChecksum(r, hdr, pkt) { + if !verifyChecksum(hdr, pkt) { // Checksum Error. e.stack.Stats().UDP.ChecksumErrors.Increment() e.stats.ReceiveErrors.ChecksumErrors.Increment() @@ -1428,7 +1443,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // Push new packet into receive list and increment the buffer size. packet := &udpPacket{ senderAddress: tcpip.FullAddress{ - NIC: r.NICID(), + NIC: pkt.NICID, Addr: id.RemoteAddress, Port: header.UDP(hdr).SourcePort(), }, @@ -1438,7 +1453,7 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk e.rcvBufSize += pkt.Data.Size() // Save any useful information from the network header to the packet. - switch r.NetProto { + switch pkt.NetworkProtocolNumber { case header.IPv4ProtocolNumber: packet.tos, _ = header.IPv4(pkt.NetworkHeader().View()).TOS() case header.IPv6ProtocolNumber: @@ -1448,9 +1463,10 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk // TODO(gvisor.dev/issue/3556): r.LocalAddress may be a multicast or broadcast // address. packetInfo.LocalAddr should hold a unicast address that can be // used to respond to the incoming packet. - packet.packetInfo.LocalAddr = r.LocalAddress - packet.packetInfo.DestinationAddr = r.LocalAddress - packet.packetInfo.NIC = r.NICID() + localAddr := pkt.Network().DestinationAddress() + packet.packetInfo.LocalAddr = localAddr + packet.packetInfo.DestinationAddr = localAddr + packet.packetInfo.NIC = pkt.NICID packet.timestamp = e.stack.Clock().NowNanoseconds() e.rcvMu.Unlock() @@ -1465,14 +1481,16 @@ func (e *endpoint) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pk func (e *endpoint) HandleControlPacket(id stack.TransportEndpointID, typ stack.ControlType, extra uint32, pkt *stack.PacketBuffer) { if typ == stack.ControlPortUnreachable { e.mu.RLock() - defer e.mu.RUnlock() - if e.state == StateConnected { e.lastErrorMu.Lock() - defer e.lastErrorMu.Unlock() - e.lastError = tcpip.ErrConnectionRefused + e.lastErrorMu.Unlock() + e.mu.RUnlock() + + e.waiterQueue.Notify(waiter.EventErr) + return } + e.mu.RUnlock() } } diff --git a/pkg/tcpip/transport/udp/forwarder.go b/pkg/tcpip/transport/udp/forwarder.go index 3ae6cc221..14e4648cd 100644 --- a/pkg/tcpip/transport/udp/forwarder.go +++ b/pkg/tcpip/transport/udp/forwarder.go @@ -43,10 +43,9 @@ func NewForwarder(s *stack.Stack, handler func(*ForwarderRequest)) *Forwarder { // // This function is expected to be passed as an argument to the // stack.SetTransportProtocolHandler function. -func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { +func (f *Forwarder) HandlePacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) bool { f.handler(&ForwarderRequest{ stack: f.stack, - route: r, id: id, pkt: pkt, }) @@ -59,7 +58,6 @@ func (f *Forwarder) HandlePacket(r *stack.Route, id stack.TransportEndpointID, p // it via CreateEndpoint. type ForwarderRequest struct { stack *stack.Stack - route *stack.Route id stack.TransportEndpointID pkt *stack.PacketBuffer } @@ -72,17 +70,25 @@ func (r *ForwarderRequest) ID() stack.TransportEndpointID { // CreateEndpoint creates a connected UDP endpoint for the session request. func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, *tcpip.Error) { - ep := newEndpoint(r.stack, r.route.NetProto, queue) - if err := r.stack.RegisterTransportEndpoint(r.route.NICID(), []tcpip.NetworkProtocolNumber{r.route.NetProto}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { + netHdr := r.pkt.Network() + route, err := r.stack.FindRoute(r.pkt.NICID, netHdr.DestinationAddress(), netHdr.SourceAddress(), r.pkt.NetworkProtocolNumber, false /* multicastLoop */) + if err != nil { + return nil, err + } + route.ResolveWith(r.pkt.SourceLinkAddress()) + + ep := newEndpoint(r.stack, r.pkt.NetworkProtocolNumber, queue) + if err := r.stack.RegisterTransportEndpoint(r.pkt.NICID, []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber}, ProtocolNumber, r.id, ep, ep.portFlags, ep.bindToDevice); err != nil { ep.Close() + route.Release() return nil, err } ep.ID = r.id - ep.route = r.route.Clone() + ep.route = route ep.dstPort = r.id.RemotePort - ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.route.NetProto} - ep.RegisterNICID = r.route.NICID() + ep.effectiveNetProtos = []tcpip.NetworkProtocolNumber{r.pkt.NetworkProtocolNumber} + ep.RegisterNICID = r.pkt.NICID ep.boundPortFlags = ep.portFlags ep.state = StateConnected @@ -91,7 +97,7 @@ func (r *ForwarderRequest) CreateEndpoint(queue *waiter.Queue) (tcpip.Endpoint, ep.rcvReady = true ep.rcvMu.Unlock() - ep.HandlePacket(r.route, r.id, r.pkt) + ep.HandlePacket(r.id, r.pkt) return ep, nil } diff --git a/pkg/tcpip/transport/udp/protocol.go b/pkg/tcpip/transport/udp/protocol.go index da5b1deb2..91420edd3 100644 --- a/pkg/tcpip/transport/udp/protocol.go +++ b/pkg/tcpip/transport/udp/protocol.go @@ -78,15 +78,15 @@ func (*protocol) ParsePorts(v buffer.View) (src, dst uint16, err *tcpip.Error) { // HandleUnknownDestinationPacket handles packets that are targeted at this // protocol but don't match any existing endpoint. -func (p *protocol) HandleUnknownDestinationPacket(r *stack.Route, id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { +func (p *protocol) HandleUnknownDestinationPacket(id stack.TransportEndpointID, pkt *stack.PacketBuffer) stack.UnknownDestinationPacketDisposition { hdr := header.UDP(pkt.TransportHeader().View()) if int(hdr.Length()) > pkt.Data.Size()+header.UDPMinimumSize { - r.Stack().Stats().UDP.MalformedPacketsReceived.Increment() + p.stack.Stats().UDP.MalformedPacketsReceived.Increment() return stack.UnknownDestinationPacketMalformed } - if !verifyChecksum(r, hdr, pkt) { - r.Stack().Stats().UDP.ChecksumErrors.Increment() + if !verifyChecksum(hdr, pkt) { + p.stack.Stats().UDP.ChecksumErrors.Increment() return stack.UnknownDestinationPacketMalformed } diff --git a/pkg/tcpip/transport/udp/udp_test.go b/pkg/tcpip/transport/udp/udp_test.go index b4604ba35..fb7738dda 100644 --- a/pkg/tcpip/transport/udp/udp_test.go +++ b/pkg/tcpip/transport/udp/udp_test.go @@ -1452,6 +1452,10 @@ func (*testInterface) Enabled() bool { return true } +func (*testInterface) WritePacketToRemote(tcpip.LinkAddress, *stack.GSO, tcpip.NetworkProtocolNumber, *stack.PacketBuffer) *tcpip.Error { + return tcpip.ErrNotSupported +} + func TestTTL(t *testing.T) { for _, flow := range []testFlow{unicastV4, unicastV4in6, unicastV6, unicastV6Only, multicastV4, multicastV4in6, multicastV6, broadcast, broadcastIn6} { t.Run(fmt.Sprintf("flow:%s", flow), func(t *testing.T) { @@ -1791,7 +1795,6 @@ func TestV4UnknownDestination(t *testing.T) { // had only a minimal IP header but the ICMP sender will have allowed // for a maximally sized packet header. wantLen = header.IPv4MinimumProcessableDatagramSize - header.IPv4MaximumHeaderSize - header.ICMPv4MinimumSize - incomingHeaderLength - } // In the case of large payloads the IP packet may be truncated. Update diff --git a/pkg/unet/unet_test.go b/pkg/unet/unet_test.go index 5c4b9e8e9..a38ffc19d 100644 --- a/pkg/unet/unet_test.go +++ b/pkg/unet/unet_test.go @@ -53,40 +53,40 @@ func randomFilename() (string, error) { func TestConnectFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } if _, err := Connect(name, false); err == nil { - t.Fatalf("connect was successful, expected err") + t.Fatalf("Connect was successful, expected err") } } func TestBindFailure(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() if _, err = BindAndListen(name, false); err == nil { - t.Fatalf("second bind succeeded, expected non-nil err") + t.Fatalf("Second bind succeeded, expected non-nil err") } } func TestMultipleAccept(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } defer ss.Close() @@ -99,7 +99,8 @@ func TestMultipleAccept(t *testing.T) { defer wg.Done() s, err := Connect(name, false) if err != nil { - t.Fatalf("connect failed, got err %v expected nil", err) + t.Errorf("Connect failed, got err %v expected nil", err) + return } s.Close() }() @@ -109,7 +110,7 @@ func TestMultipleAccept(t *testing.T) { for i := 0; i < backlog; i++ { s, err := ss.Accept() if err != nil { - t.Errorf("accept failed, got err %v expected nil", err) + t.Errorf("Accept failed, got err %v expected nil", err) continue } s.Close() @@ -119,35 +120,35 @@ func TestMultipleAccept(t *testing.T) { func TestServerClose(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("first bind failed, got err %v expected nil", err) + t.Fatalf("First bind failed, got err %v expected nil", err) } // Make sure the first close succeeds. if err := ss.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := ss.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } // Bind a server. ss, err := BindAndListen(name, packet) if err != nil { - t.Fatalf("error binding, got %v expected nil", err) + t.Fatalf("Error binding, got %v expected nil", err) } defer ss.Close() @@ -165,7 +166,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { // Connect the client. client, err := Connect(name, packet) if err != nil { - t.Fatalf("error connecting, got %v expected nil", err) + t.Fatalf("Error connecting, got %v expected nil", err) } // Grab the server handle. @@ -173,7 +174,7 @@ func socketPair(t *testing.T, packet bool) (*Socket, *Socket) { case server := <-acceptSocket: return server, client case err := <-acceptErr: - t.Fatalf("accept error: %v", err) + t.Fatalf("Accept error: %v", err) } panic("unreachable") } @@ -186,17 +187,17 @@ func TestSendRecv(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. b := [][]byte{{'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -211,17 +212,17 @@ func TestSymmetric(t *testing.T) { // Write on the server. w := server.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -233,13 +234,13 @@ func TestPacket(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Write on the client again. w = client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the server. @@ -249,19 +250,19 @@ func TestPacket(t *testing.T) { b := [][]byte{{'b', 'b'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Do it again. r = server.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } } @@ -271,12 +272,12 @@ func TestClose(t *testing.T) { // Make sure the first close succeeds. if err := client.Close(); err != nil { - t.Fatalf("first close failed, got err %v expected nil", err) + t.Fatalf("First close failed, got err %v expected nil", err) } // The second one should fail. if err := client.Close(); err == nil { - t.Fatalf("second close succeeded, expected non-nil err") + t.Fatalf("Second close succeeded, expected non-nil err") } } @@ -294,17 +295,17 @@ func TestNonBlockingSend(t *testing.T) { // We're good. That's what we wanted. blockCount++ } else { - t.Fatalf("for client write, got n=%d err=%v, expected n=1000 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=1000 err=nil", n, err) } } } if blockCount == 1000 { // Shouldn't have _always_ blocked. - t.Fatalf("socket always blocked!") + t.Fatalf("Socket always blocked!") } else if blockCount == 0 { // Should have started blocking eventually. - t.Fatalf("socket never blocked!") + t.Fatalf("Socket never blocked!") } } @@ -319,25 +320,25 @@ func TestNonBlockingRecv(t *testing.T) { // Expected to block immediately. _, err := r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } // Put some data in the pipe. w := server.Writer(false) if n, err := w.WriteVec(b); n != 1 || err != nil { - t.Fatalf("write failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Write failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it not to block. if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("read failed with n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("Read failed with n=%d err=%v, expected n=1 err=nil", n, err) } // Expect it to return a block error again. r = client.Reader(false) _, err = r.ReadVec(b) if err != syscall.EWOULDBLOCK && err != syscall.EAGAIN { - t.Fatalf("read didn't block, got err %v expected blocking err", err) + t.Fatalf("Read didn't block, got err %v expected blocking err", err) } } @@ -349,17 +350,17 @@ func TestRecvVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a', 'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. b := [][]byte{{'c'}, {'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[1][0] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[1][0]) } } @@ -371,17 +372,17 @@ func TestSendVectors(t *testing.T) { // Write on the client. w := client.Writer(true) if n, err := w.WriteVec([][]byte{{'a'}, {'b'}}); n != 2 || err != nil { - t.Fatalf("for client write, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For client write, got n=%d err=%v, expected n=2 err=nil", n, err) } // Read on the server. b := [][]byte{{'c', 'c'}} r := server.Reader(true) if n, err := r.ReadVec(b); n != 2 || err != nil { - t.Fatalf("for server read, got n=%d err=%v, expected n=2 err=nil", n, err) + t.Fatalf("For server read, got n=%d err=%v, expected n=2 err=nil", n, err) } if b[0][0] != 'a' || b[0][1] != 'b' { - t.Fatalf("got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) + t.Fatalf("Got bad read data, got %c,%c, expected a,b", b[0][0], b[0][1]) } } @@ -394,23 +395,23 @@ func TestSendFDsNotEnabled(t *testing.T) { w := server.Writer(true) w.PackFDs(0, 1, 2) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for server write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For server write, got n=%d err=%v, expected n=1 err=nil", n, err) } // Read on the client, without enabling FDs. b := [][]byte{{'b'}} r := client.Reader(true) if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Make sure the FDs are not received. fds, err := r.ExtractFDs() if len(fds) != 0 || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=0 err=nil", fds, err) } } @@ -418,7 +419,7 @@ func sendFDs(t *testing.T, s *Socket, fds []int) { w := s.Writer(true) w.PackFDs(fds...) if n, err := w.WriteVec([][]byte{{'a'}}); n != 1 || err != nil { - t.Fatalf("for write, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For write, got n=%d err=%v, expected n=1 err=nil", n, err) } } @@ -428,7 +429,7 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Count the number of FDs. preEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } // Read on the client. @@ -438,31 +439,31 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { r.EnableFDs(enableSize) } if n, err := r.ReadVec(b); n != 1 || err != nil { - t.Fatalf("for client read, got n=%d err=%v, expected n=1 err=nil", n, err) + t.Fatalf("For client read, got n=%d err=%v, expected n=1 err=nil", n, err) } if b[0][0] != 'a' { - t.Fatalf("got bad read data, got %c, expected a", b[0][0]) + t.Fatalf("Got bad read data, got %c, expected a", b[0][0]) } // Count the new number of FDs. postEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(preEntries)+expected != len(postEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries)+expected, len(postEntries)) } // Make sure the FDs are there. fds, err := r.ExtractFDs() if len(fds) != expected || err != nil { - t.Fatalf("got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) + t.Fatalf("Got fds=%v err=%v, expected len(fds)=%d err=nil", fds, err, expected) } // Make sure they are different from the originals. for i := 0; i < len(fds); i++ { if fds[i] == origFDs[i] { - t.Errorf("got original fd for index %d, expected different", i) + t.Errorf("Got original fd for index %d, expected different", i) } } @@ -480,10 +481,10 @@ func recvFDs(t *testing.T, s *Socket, enableSize int, origFDs []int) { // Make sure the count is back to normal. finalEntries, err := ioutil.ReadDir("/proc/self/fd") if err != nil { - t.Fatalf("can't readdir, got err %v expected nil", err) + t.Fatalf("Can't readdir, got err %v expected nil", err) } if len(finalEntries) != len(preEntries) { - t.Errorf("process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) + t.Errorf("Process fd count isn't right, expected %d got %d", len(preEntries), len(finalEntries)) } } @@ -567,7 +568,7 @@ func TestGetPeerCred(t *testing.T) { } if got, err := client.GetPeerCred(); err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) + t.Errorf("GetPeerCred() = %v, %v, want = %+v, %+v", got, err, want, nil) } } @@ -594,53 +595,53 @@ func TestGetPeerCredFailure(t *testing.T) { want := "bad file descriptor" if _, err := s.GetPeerCred(); err == nil || err.Error() != want { - t.Errorf("got s.GetPeerCred() = %v, want = %s", err, want) + t.Errorf("s.GetPeerCred() = %v, want = %s", err, want) } } func TestAcceptClosed(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Fatalf("Close failed, got err %v expected nil", err) } if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } } func TestCloseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) if err := ss.Close(); err != nil { - t.Fatalf("close failed, got err %v expected nil", err) + t.Errorf("Close failed, got err %v expected nil", err) } - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -649,28 +650,28 @@ func TestCloseAfterAcceptStart(t *testing.T) { func TestReleaseAfterAcceptStart(t *testing.T) { name, err := randomFilename() if err != nil { - t.Fatalf("unable to generate file, got err %v expected nil", err) + t.Fatalf("Unable to generate file, got err %v expected nil", err) } ss, err := BindAndListen(name, false) if err != nil { - t.Fatalf("bind failed, got err %v expected nil", err) + t.Fatalf("Bind failed, got err %v expected nil", err) } wg := sync.WaitGroup{} wg.Add(1) go func() { + defer wg.Done() time.Sleep(50 * time.Millisecond) fd, err := ss.Release() if err != nil { - t.Fatalf("Release failed, got err %v expected nil", err) + t.Errorf("Release failed, got err %v expected nil", err) } syscall.Close(fd) - wg.Done() }() if _, err := ss.Accept(); err == nil { - t.Errorf("accept on closed SocketServer, got err %v, want != nil", err) + t.Errorf("Accept on closed SocketServer, got err %v, want != nil", err) } wg.Wait() @@ -688,7 +689,7 @@ func TestControlMessage(t *testing.T) { cm.PackFDs(want...) got, err := cm.ExtractFDs() if err != nil || !reflect.DeepEqual(got, want) { - t.Errorf("got cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) + t.Errorf("cm.ExtractFDs() = %v, %v, want = %v, %v", got, err, want, nil) } } } @@ -705,11 +706,13 @@ func benchmarkSendRecv(b *testing.B, packet bool) { for i := 0; i < b.N; i++ { n, err := server.Read(buf) if n != 1 || err != nil { - b.Fatalf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Read: got (%d, %v), wanted (1, nil)", n, err) + return } n, err = server.Write(buf) if n != 1 || err != nil { - b.Fatalf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + b.Errorf("server.Write: got (%d, %v), wanted (1, nil)", n, err) + return } } }() diff --git a/pkg/waiter/waiter.go b/pkg/waiter/waiter.go index 67a950444..08519d986 100644 --- a/pkg/waiter/waiter.go +++ b/pkg/waiter/waiter.go @@ -168,7 +168,7 @@ func NewChannelEntry(c chan struct{}) (Entry, chan struct{}) { // // +stateify savable type Queue struct { - list waiterList `state:"zerovalue"` + list waiterList mu sync.RWMutex `state:"nosave"` } |